├── .gitignore ├── .idea ├── vcs.xml ├── .gitignore ├── modules.xml ├── gomatrixserverlib.iml └── watcherTasks.xml ├── .github ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── tests.yml ├── README.md ├── .golangci.yml ├── create.go ├── spec ├── timestamp.go ├── roomtypes.go ├── rawjson.go ├── servername.go ├── roomid.go ├── roomid_test.go ├── senderid_test.go ├── senderid.go ├── userid.go ├── userid_test.go ├── base64.go ├── eventtypes.go ├── matrixerror_test.go └── base64_test.go ├── linter.json ├── travis.sh ├── sendtodevice.go ├── device_update.go ├── event_examples_test.go ├── fclient ├── relaytypes.go ├── invitev3.go ├── invitev2.go ├── invitev2_test.go ├── dnscache_test.go ├── invitev3_test.go ├── crosssigning.go ├── well_known.go ├── dnscache.go ├── crosssigning_test.go ├── federationtypes_test.go └── resolve.go ├── edu.go ├── eventV3_test.go ├── errors.go ├── go.mod ├── hex_string.go ├── transaction.go ├── tokens ├── tokens_test.go ├── tokens_handlers_test.go ├── tokens_handlers.go └── tokens.go ├── headeredevent_test.go ├── hex_string_test.go ├── authchain.go ├── join.go ├── eventversion_test.go ├── invite_test.go ├── event.go ├── handleleave.go ├── stateresolutionv2heaps.go ├── eventV2_test.go ├── signing.go ├── pdu.go ├── eventV1_test.go ├── load.go ├── eventV3.go ├── backfill.go ├── keys.go ├── signing_test.go ├── redactevent_test.go └── go.sum /.gitignore: -------------------------------------------------------------------------------- 1 | .*.swp 2 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # .gitignore for .idea 2 | # 3 | # per 
https://intellij-support.jetbrains.com/hc/en-us/articles/206544839-How-to-manage-projects-under-Version-Control-Systems 4 | 5 | /tasks.xml 6 | /workspace.xml 7 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### Pull Request Checklist 2 | 3 | * [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/docs/CONTRIBUTING.md#sign-off) 4 | 5 | Signed-off-by: `Your Name <your@email.example.org>` 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | gomatrixserverlib 2 | ================= 3 | 4 | [![GoDoc](https://godoc.org/github.com/matrix-org/gomatrixserverlib?status.svg)](https://godoc.org/github.com/matrix-org/gomatrixserverlib) 5 | 6 | Go library for common functions needed by matrix servers. This library assumes Go 1.23+.
7 | 
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | 
6 | 
7 | 
8 | 
--------------------------------------------------------------------------------
/.idea/gomatrixserverlib.iml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | 
6 | 
7 | 
8 | 
9 | 
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | run:
2 |   timeout: 5m
3 | linters:
4 |   enable:
5 |     - typecheck
6 |     - gocyclo
7 |     - ineffassign
8 | #    - gosec - complains about weak cryptographic primitive sha1 and TLS InsecureSkipVerify set true in getTransport
9 |     - misspell
10 |     - unparam
11 |     - goimports
12 | #    - goconst
13 |     - unconvert
14 |     - errcheck
15 | #    - testify - not available in golangci-lint
16 | 
--------------------------------------------------------------------------------
/create.go:
--------------------------------------------------------------------------------
1 | package gomatrixserverlib
2 | 
3 | // FledglingEvent is a helper representation of an event used when creating many events in succession.
4 | type FledglingEvent struct {
5 | 	// The type of the event.
6 | 	Type string `json:"type"`
7 | 	// The state_key of the event. This is a plain string, not a pointer, so it
8 | 	// can never be nil and is always serialised (an empty value as ""); it does
9 | 	// not by itself distinguish state events from non-state events.
10 | 	StateKey string `json:"state_key"`
11 | 	// The value serialised as the "content" key of the event.
12 | 	Content any `json:"content"`
13 | }
14 | 
--------------------------------------------------------------------------------
/spec/timestamp.go:
--------------------------------------------------------------------------------
1 | package spec
2 | 
3 | import (
4 | 	"time"
5 | )
6 | 
7 | // A Timestamp is a millisecond posix timestamp.
8 | type Timestamp uint64
9 | 
10 | // AsTimestamp turns a time.Time into a millisecond posix timestamp.
11 | func AsTimestamp(t time.Time) Timestamp {
12 | 	return Timestamp(t.UnixMilli())
13 | }
14 | 
15 | // Time turns a millisecond posix timestamp into a UTC time.Time; it is the counterpart of AsTimestamp.
16 | func (t Timestamp) Time() time.Time {
17 | 	return time.UnixMilli(int64(t)).UTC()
18 | }
19 | 
--------------------------------------------------------------------------------
/linter.json:
--------------------------------------------------------------------------------
1 | {
2 |     "Deadline": "5m",
3 |     "Enable": [
4 |         "vet",
5 |         "vetshadow",
6 |         "gotype",
7 |         "deadcode",
8 |         "gocyclo",
9 |         "golint",
10 |         "varcheck",
11 |         "structcheck",
12 |         "maligned",
13 |         "ineffassign",
14 |         "gosec",
15 |         "misspell",
16 |         "unparam",
17 |         "goimports",
18 |         "goconst",
19 |         "unconvert",
20 |         "errcheck",
21 |         "interfacer",
22 |         "testify"
23 |     ]
24 | }
25 | 
--------------------------------------------------------------------------------
/travis.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | set -eux
4 | 
5 | cd `dirname $0`
6 | 
7 | # -u so that if this is run on a dev box, we get the latest deps, as
8 | # we do on travis.
9 | 
10 | go get -u \
11 |     github.com/client9/misspell/cmd/misspell \
12 |     golang.org/x/crypto/ed25519 \
13 |     github.com/matrix-org/util \
14 |     github.com/matrix-org/gomatrix \
15 |     github.com/tidwall/gjson \
16 |     github.com/tidwall/sjson \
17 |     github.com/pkg/errors \
18 |     gopkg.in/yaml.v2 \
19 |     gopkg.in/macaroon.v2 \
20 | 
21 | ./hooks/pre-commit
22 | 
--------------------------------------------------------------------------------
/sendtodevice.go:
--------------------------------------------------------------------------------
1 | package gomatrixserverlib
2 | 
3 | import "encoding/json"
4 | 
5 | // SendToDeviceEvent is a single send-to-device event: who sent it, its event
6 | // type, and its content as raw, undecoded JSON.
7 | type SendToDeviceEvent struct {
8 | 	Sender  string          `json:"sender"`
9 | 	Type    string          `json:"type"`
10 | 	Content json.RawMessage `json:"content"`
11 | }
12 | 
13 | // ToDeviceMessage is a batch of send-to-device messages of one type from one
14 | // sender, identified by MessageID. Messages is a two-level map whose leaf
15 | // values are raw JSON content; presumably keyed by user ID and then device ID,
16 | // as in the m.direct_to_device EDU (NOTE(review): confirm against callers).
17 | type ToDeviceMessage struct {
18 | 	Sender    string                                `json:"sender"`
19 | 	Type      string                                `json:"type"`
20 | 	MessageID string                                `json:"message_id"`
21 | 	Messages  map[string]map[string]json.RawMessage `json:"messages"`
22 | }
23 | 
--------------------------------------------------------------------------------
/device_update.go:
--------------------------------------------------------------------------------
1 | package gomatrixserverlib
2 | 
3 | import "encoding/json"
4 | 
5 | // DeviceListUpdateEvent mirrors the m.device_list_update EDU schema described at
6 | // https://matrix.org/docs/spec/server_server/latest#m-device-list-update-schema
7 | type DeviceListUpdateEvent struct {
8 | 	UserID            string          `json:"user_id"`
9 | 	DeviceID          string          `json:"device_id"`
10 | 	DeviceDisplayName string          `json:"device_display_name,omitempty"`
11 | 	StreamID          int64           `json:"stream_id"`
12 | 	PrevID            []int64         `json:"prev_id,omitempty"`
13 | 	Deleted           bool            `json:"deleted,omitempty"`
14 | 	Keys              json.RawMessage `json:"keys,omitempty"`
15 | }
16 | 
--------------------------------------------------------------------------------
/spec/roomtypes.go:
--------------------------------------------------------------------------------
1 | // Copyright 2023 The Matrix.org Foundation C.I.C.
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package spec 16 | 17 | const ( 18 | // MSpace https://spec.matrix.org/v1.7/client-server-api/#types 19 | MSpace = "m.space" 20 | ) 21 | -------------------------------------------------------------------------------- /event_examples_test.go: -------------------------------------------------------------------------------- 1 | /* Copyright 2016-2017 Vector Creations Ltd 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package gomatrixserverlib 17 | 18 | import "fmt" 19 | 20 | func ExampleSplitID() { 21 | localpart, domain, err := SplitID('@', "@alice:localhost:8080") 22 | if err != nil { 23 | panic(err) 24 | } 25 | fmt.Println(localpart, domain) 26 | // Output: alice localhost:8080 27 | } 28 | -------------------------------------------------------------------------------- /fclient/relaytypes.go: -------------------------------------------------------------------------------- 1 | package fclient 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/matrix-org/gomatrixserverlib" 7 | ) 8 | 9 | // A RelayEntry is used to track the nid of an event received from a relay server. 10 | // It is used as the request body of a GET to /_matrix/federation/v1/relay_txn/{userID} 11 | type RelayEntry struct { 12 | EntryID int64 `json:"entry_id"` 13 | } 14 | 15 | // A RespGetRelayTransaction is the response body of a successful GET to /_matrix/federation/v1/relay_txn/{userID} 16 | type RespGetRelayTransaction struct { 17 | Transaction gomatrixserverlib.Transaction `json:"transaction"` 18 | EntryID int64 `json:"entry_id,omitempty"` 19 | EntriesQueued bool `json:"entries_queued"` 20 | } 21 | 22 | // RelayEvents is the request body of a PUT to /_matrix/federation/v1/send_relay/{txnID}/{userID} 23 | type RelayEvents struct { 24 | PDUs []json.RawMessage `json:"pdus"` 25 | EDUs []gomatrixserverlib.EDU `json:"edus"` 26 | } 27 | -------------------------------------------------------------------------------- /spec/rawjson.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | // TODO: Remove. Since Go 1.8 this has been fixed. 
4 | // RawJSON is a reimplementation of json.RawMessage that supports being used as a value type 5 | // 6 | // For example: 7 | // 8 | // jsonBytes, _ := json.Marshal(struct{ 9 | // RawMessage json.RawMessage 10 | // RawJSON RawJSON 11 | // }{ 12 | // json.RawMessage(`"Hello"`), 13 | // RawJSON(`"World"`), 14 | // }) 15 | // 16 | // Results in: 17 | // 18 | // {"RawMessage":"IkhlbGxvIg==","RawJSON":"World"} 19 | // 20 | // See https://play.golang.org/p/FzhKIJP8-I for a full example. 21 | type RawJSON []byte 22 | 23 | // MarshalJSON implements the json.Marshaller interface using a value receiver. 24 | // This means that RawJSON used as an embedded value will still encode correctly. 25 | func (r RawJSON) MarshalJSON() ([]byte, error) { 26 | return []byte(r), nil 27 | } 28 | 29 | // UnmarshalJSON implements the json.Unmarshaller interface using a pointer receiver. 30 | func (r *RawJSON) UnmarshalJSON(data []byte) error { 31 | *r = RawJSON(data) 32 | return nil 33 | } 34 | -------------------------------------------------------------------------------- /.idea/watcherTasks.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 16 | 28 | 29 | -------------------------------------------------------------------------------- /edu.go: -------------------------------------------------------------------------------- 1 | /* Licensed under the Apache License, Version 2.0 (the "License"); 2 | * you may not use this file except in compliance with the License. 3 | * You may obtain a copy of the License at 4 | * 5 | * http://www.apache.org/licenses/LICENSE-2.0 6 | * 7 | * Unless required by applicable law or agreed to in writing, software 8 | * distributed under the License is distributed on an "AS IS" BASIS, 9 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | * See the License for the specific language governing permissions and 11 | * limitations under the License. 
12 |  */
13 | 
14 | package gomatrixserverlib
15 | 
16 | import (
17 | 	"unsafe"
18 | 
19 | 	"github.com/matrix-org/gomatrixserverlib/spec"
20 | )
21 | 
22 | // EDU represents an EDU (ephemeral data unit) received via federation.
23 | // https://matrix.org/docs/spec/server_server/unstable.html#edus
24 | type EDU struct {
25 | 	Type        string       `json:"edu_type"`
26 | 	Origin      string       `json:"origin"`
27 | 	Destination string       `json:"destination,omitempty"`
28 | 	Content     spec.RawJSON `json:"content,omitempty"`
29 | }
30 | 
31 | // CacheCost returns an approximate in-memory size of the EDU in bytes: the
32 | // size of the struct itself plus the lengths of its string fields and the
33 | // capacity of its Content buffer.
34 | func (e *EDU) CacheCost() int {
35 | 	return int(unsafe.Sizeof(*e)) +
36 | 		len(e.Type) +
37 | 		len(e.Origin) +
38 | 		len(e.Destination) +
39 | 		cap(e.Content)
40 | }
41 | 
--------------------------------------------------------------------------------
/eventV3_test.go:
--------------------------------------------------------------------------------
1 | package gomatrixserverlib
2 | 
3 | import (
4 | 	"testing"
5 | 	"time"
6 | 
7 | 	"github.com/matrix-org/gomatrixserverlib/spec"
8 | 	"github.com/stretchr/testify/assert"
9 | 	"golang.org/x/crypto/ed25519"
10 | )
11 | 
12 | func TestEventCreationV3(t *testing.T) {
13 | 	_, sk, err := ed25519.GenerateKey(nil)
14 | 	assert.NoError(t, err)
15 | 	verImpl := MustGetRoomVersion(RoomVersionV12)
16 | 	sender := "@alice:example.com"
17 | 
18 | 	// Ensure we can make create events
19 | 	ev, err := verImpl.NewEventBuilderFromProtoEvent(&ProtoEvent{
20 | 		Type:     spec.MRoomCreate,
21 | 		StateKey: &emptyStateKey,
22 | 		Content:  []byte(`{"room_version":"12"}`),
23 | 		SenderID: sender,
24 | 		Depth:    1,
25 | 	}).Build(time.Now(), "localhost", "ed25519:1", sk)
26 | 	assert.NoError(t, err, "failed to build create event")
27 | 	assert.Equal(t, ev.EventID()[1:], ev.RoomID().String()[1:], "create event ID must equal the room ID")
28 | 
29 | 	// ..and use the new room ID to make other events
30 | 	_, err = verImpl.NewEventBuilderFromProtoEvent(&ProtoEvent{
31 | 		Type:     spec.MRoomMember,
32 | 		StateKey: &sender,
33 | 		Content:  []byte(`{"membership":"join"}`),
34 | 		SenderID: sender,
35 | 		Depth:    1,
36 | 		RoomID:   ev.RoomID().String(),
37 | 	}).Build(time.Now(), "localhost", "ed25519:1", sk)
38 | 	assert.NoError(t, err, "failed to build member event")
39 | }
40 | 
--------------------------------------------------------------------------------
/errors.go:
--------------------------------------------------------------------------------
1 | package gomatrixserverlib
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	"github.com/matrix-org/gomatrixserverlib/spec"
7 | )
8 | 
9 | // MissingAuthEventError refers to a situation where one of the auth
10 | // events for a given event was not found.
11 | type MissingAuthEventError struct {
12 | 	AuthEventID string // The ID of the auth event that could not be found.
13 | 	ForEventID  string // The ID of the event that referenced the missing auth event.
14 | }
15 | 
16 | func (e MissingAuthEventError) Error() string {
17 | 	return fmt.Sprintf(
18 | 		"gomatrixserverlib: missing auth event with ID %s for event %s",
19 | 		e.AuthEventID, e.ForEventID,
20 | 	)
21 | }
22 | 
23 | // BadJSONError wraps an underlying error and reports it as bad JSON.
24 | type BadJSONError struct {
25 | 	err error
26 | }
27 | 
28 | // Error implements the error interface. A nil wrapped error is reported as
29 | // "<nil>" rather than causing a nil-pointer panic.
30 | func (e BadJSONError) Error() string {
31 | 	if e.err == nil {
32 | 		return "gomatrixserverlib: bad JSON: <nil>"
33 | 	}
34 | 	return fmt.Sprintf("gomatrixserverlib: bad JSON: %s", e.err.Error())
35 | }
36 | 
37 | // Unwrap returns the wrapped error so errors.Is and errors.As can inspect it.
38 | func (e BadJSONError) Unwrap() error {
39 | 	return e.err
40 | }
41 | 
42 | // FederationError contains context surrounding why a federation request may have failed.
43 | type FederationError struct {
44 | 	ServerName spec.ServerName // The server being contacted.
45 | 	Transient  bool            // Whether the failure is transient (retrying may succeed) rather than permanent.
46 | 	Reachable  bool            // Whether the server could be contacted.
47 | 	Err        error           // The underlying error.
48 | }
49 | 
50 | // Error implements the error interface. A nil Err is reported as "<nil>"
51 | // rather than causing a nil-pointer panic.
52 | func (e FederationError) Error() string {
53 | 	cause := "<nil>"
54 | 	if e.Err != nil {
55 | 		cause = e.Err.Error()
56 | 	}
57 | 	return fmt.Sprintf("FederationError(t=%v, r=%v): %s", e.Transient, e.Reachable, cause)
58 | }
59 | 
60 | // Unwrap returns the underlying error so errors.Is and errors.As can inspect
61 | // it, consistent with BadJSONError.Unwrap.
62 | func (e FederationError) Unwrap() error {
63 | 	return e.Err
64 | }
65 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/matrix-org/gomatrixserverlib
2 | 
3 | require (
4 | 	github.com/google/go-cmp v0.7.0
5 | 	github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
6 | 	github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
7 | 	github.com/miekg/dns v1.1.66
8 | 	github.com/sirupsen/logrus v1.9.3
9 | 	github.com/stretchr/testify v1.10.0
10 | 	github.com/tidwall/gjson v1.18.0
11 | 	github.com/tidwall/sjson v1.2.5
12 | 	golang.org/x/crypto v0.38.0
13 | 	gopkg.in/h2non/gock.v1 v1.1.2
14 | 	gopkg.in/macaroon.v2 v2.1.0
15 | 	gopkg.in/yaml.v2 v2.4.0
16 | )
17 | 
18 | require (
19 | 	github.com/davecgh/go-spew v1.1.1 // indirect
20 | 	github.com/frankban/quicktest v1.14.6 // indirect
21 | 	github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
22 | 	github.com/hashicorp/go-set/v3 v3.0.0 // indirect
23 | 	github.com/oleiade/lane/v2 v2.0.0 // indirect
24 | 	github.com/pmezard/go-difflib v1.0.0 // indirect
25 | 	github.com/tidwall/match v1.1.1 // indirect
26 | 	github.com/tidwall/pretty v1.2.1 // indirect
27 | 	golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
28 | 	golang.org/x/mod v0.24.0 // indirect
29 | 	golang.org/x/net v0.40.0 // indirect
30 | 	golang.org/x/sync v0.14.0 // indirect
31 | 	golang.org/x/sys v0.33.0 // indirect
32 | 	golang.org/x/tools v0.33.0 // indirect
33 | 	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
34 | 	gopkg.in/yaml.v3 v3.0.1 // indirect
35 | )
36 | 
37 | go 1.23.0
38 | 
39 | toolchain go1.24.3
40 | 
--------------------------------------------------------------------------------
/hex_string.go:
--------------------------------------------------------------------------------
1
| /* Copyright 2017 New Vector Ltd
2 |  *
3 |  * Licensed under the Apache License, Version 2.0 (the "License");
4 |  * you may not use this file except in compliance with the License.
5 |  * You may obtain a copy of the License at
6 |  *
7 |  *     http://www.apache.org/licenses/LICENSE-2.0
8 |  *
9 |  * Unless required by applicable law or agreed to in writing, software
10 |  * distributed under the License is distributed on an "AS IS" BASIS,
11 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 |  * See the License for the specific language governing permissions and
13 |  * limitations under the License.
14 |  */
15 | 
16 | package gomatrixserverlib
17 | 
18 | import (
19 | 	"encoding/hex"
20 | 	"encoding/json"
21 | )
22 | 
23 | // A HexString is a string of bytes that are hex encoded when used in JSON.
24 | // The bytes are encoded as a hex JSON string when marshalled, and are
25 | // decoded from hex when unmarshalled.
26 | type HexString []byte
27 | 
28 | // MarshalJSON encodes the bytes as hex and then encodes the hex as a JSON string.
29 | // This takes a value receiver so that maps and slices of HexString encode correctly.
30 | func (h HexString) MarshalJSON() ([]byte, error) {
31 | 	return json.Marshal(hex.EncodeToString(h))
32 | }
33 | 
34 | // UnmarshalJSON decodes a JSON string and then decodes the resulting hex.
35 | // This takes a pointer receiver because it needs to write the result of decoding.
36 | // On error *h is left unchanged, so a partially decoded prefix is never stored.
37 | func (h *HexString) UnmarshalJSON(raw []byte) error {
38 | 	var str string
39 | 	if err := json.Unmarshal(raw, &str); err != nil {
40 | 		return err
41 | 	}
42 | 	decoded, err := hex.DecodeString(str)
43 | 	if err != nil {
44 | 		return err
45 | 	}
46 | 	*h = decoded
47 | 	return nil
48 | }
49 | 
--------------------------------------------------------------------------------
/transaction.go:
--------------------------------------------------------------------------------
1 | package gomatrixserverlib
2 | 
3 | import (
4 | 	"encoding/json"
5 | 
6 | 	"github.com/matrix-org/gomatrixserverlib/spec"
7 | )
8 | 
9 | // A Transaction is used to push data from one matrix server to another matrix
10 | // server.
11 | type Transaction struct {
12 | 	// The ID of the transaction.
13 | 	TransactionID TransactionID `json:"-"`
14 | 	// The server that sent the transaction.
15 | 	Origin spec.ServerName `json:"origin"`
16 | 	// The server that should receive the transaction.
17 | 	Destination spec.ServerName `json:"-"`
18 | 	// The millisecond posix timestamp on the origin server when the
19 | 	// transaction was created.
20 | 	OriginServerTS spec.Timestamp `json:"origin_server_ts"`
21 | 	// The IDs of the most recent transactions sent by the origin server to
22 | 	// the destination server. Multiple transactions can be sent by the origin
23 | 	// server to the destination server in parallel so there may be more than
24 | 	// one previous transaction.
25 | 	PreviousIDs []TransactionID `json:"-"`
26 | 	// The room events pushed from the origin server to the destination server
27 | 	// by this transaction. The events should either be events that originate
28 | 	// on the origin server or be join m.room.member events.
29 | 	PDUs []json.RawMessage `json:"pdus"`
30 | 	// The ephemeral events pushed from origin server to destination server
31 | 	// by this transaction. The events must originate at the origin server.
32 | 	EDUs []EDU `json:"edus,omitempty"`
33 | }
34 | 
35 | // A TransactionID identifies a transaction sent by a matrix server to another
36 | // matrix server. The ID must be unique amongst the transactions sent from the
37 | // origin server to the destination, but doesn't have to be globally unique.
38 | // The ID must be safe to insert into a URL path segment. The ID should have a
39 | // format matching '^[0-9A-Za-z\-_]*$'
40 | type TransactionID string
41 | 
--------------------------------------------------------------------------------
/fclient/invitev3.go:
--------------------------------------------------------------------------------
1 | package fclient
2 | 
3 | import (
4 | 	"encoding/json"
5 | 
6 | 	"github.com/matrix-org/gomatrixserverlib"
7 | )
8 | 
9 | // NewInviteV3Request builds an InviteV3Request carrying event together with
10 | // the room version and the stripped invite room state. It returns a
11 | // gomatrixserverlib.UnsupportedRoomVersionError (and a zero request) if
12 | // version is not a known room version.
13 | func NewInviteV3Request(event gomatrixserverlib.ProtoEvent, version gomatrixserverlib.RoomVersion, state []gomatrixserverlib.InviteStrippedState) (
14 | 	request InviteV3Request, err error,
15 | ) {
16 | 	if !gomatrixserverlib.KnownRoomVersion(version) {
17 | 		err = gomatrixserverlib.UnsupportedRoomVersionError{
18 | 			Version: version,
19 | 		}
20 | 		return
21 | 	}
22 | 	request.fields.inviteV2RequestHeaders = inviteV2RequestHeaders{
23 | 		RoomVersion:     version,
24 | 		InviteRoomState: state,
25 | 	}
26 | 	request.fields.Event = event
27 | 	return
28 | }
29 | 
30 | // InviteV3Request is used in the body of a /_matrix/federation/v3/invite request.
31 | type InviteV3Request struct {
32 | 	fields struct {
33 | 		inviteV2RequestHeaders
34 | 		Event gomatrixserverlib.ProtoEvent `json:"event"`
35 | 	}
36 | }
37 | 
38 | // MarshalJSON implements json.Marshaler.
39 | func (i InviteV3Request) MarshalJSON() ([]byte, error) {
40 | 	return json.Marshal(i.fields)
41 | }
42 | 
43 | // UnmarshalJSON implements json.Unmarshaler.
44 | func (i *InviteV3Request) UnmarshalJSON(data []byte) error {
45 | 	return json.Unmarshal(data, &i.fields)
46 | }
47 | 
48 | // Event returns the invite event.
49 | func (i *InviteV3Request) Event() gomatrixserverlib.ProtoEvent {
50 | 	return i.fields.Event
51 | }
52 | 
53 | // RoomVersion returns the room version of the invited room.
54 | func (i *InviteV3Request) RoomVersion() gomatrixserverlib.RoomVersion {
55 | 	return i.fields.RoomVersion
56 | }
57 | 
58 | // InviteRoomState returns stripped state events for the room, containing
59 | // enough information for the client to identify the room.
60 | func (i *InviteV3Request) InviteRoomState() []gomatrixserverlib.InviteStrippedState {
61 | 	return i.fields.InviteRoomState
62 | }
63 | 
--------------------------------------------------------------------------------
/tokens/tokens_test.go:
--------------------------------------------------------------------------------
1 | // Licensed under the Apache License, Version 2.0 (the "License");
2 | // you may not use this file except in compliance with the License.
3 | // You may obtain a copy of the License at
4 | //
5 | //     http://www.apache.org/licenses/LICENSE-2.0
6 | //
7 | // Unless required by applicable law or agreed to in writing, software
8 | // distributed under the License is distributed on an "AS IS" BASIS,
9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | // See the License for the specific language governing permissions and
11 | // limitations under the License.
12 | 13 | package tokens 14 | 15 | import ( 16 | "testing" 17 | ) 18 | 19 | var ( 20 | validTokenOp = TokenOptions{ 21 | ServerPrivateKey: []byte("aSecretKey"), 22 | ServerName: "aRandomServerName", 23 | UserID: "aRandomUserID", 24 | } 25 | invalidTokenOps = map[string]TokenOptions{ 26 | "ServerPrivateKey": { 27 | ServerName: "aRandomServerName", 28 | UserID: "aRandomUserID", 29 | }, 30 | "ServerName": { 31 | ServerPrivateKey: []byte("aSecretKey"), 32 | UserID: "aRandomUserID", 33 | }, 34 | "UserID": { 35 | ServerPrivateKey: []byte("aSecretKey"), 36 | ServerName: "aRandomServerName", 37 | }, 38 | } 39 | ) 40 | 41 | func TestGenerateLoginToken(t *testing.T) { 42 | // Test valid 43 | _, err := GenerateLoginToken(validTokenOp) 44 | if err != nil { 45 | t.Errorf("Token generation failed for valid TokenOptions with err: %s", err.Error()) 46 | } 47 | 48 | // Test invalids 49 | for missing, invalidTokenOp := range invalidTokenOps { 50 | _, err := GenerateLoginToken(invalidTokenOp) 51 | if err == nil { 52 | t.Errorf("Token generation should fail for TokenOptions with missing %s", missing) 53 | } 54 | } 55 | } 56 | 57 | func serializationTestError(err error) string { 58 | return "Token Serialization test failed with err: " + err.Error() 59 | } 60 | 61 | func TestSerialization(t *testing.T) { 62 | fakeToken, err := GenerateLoginToken(validTokenOp) 63 | if err != nil { 64 | t.Errorf("%s", serializationTestError(err)) 65 | } 66 | 67 | fakeMacaroon, err := deSerializeMacaroon(fakeToken) 68 | if err != nil { 69 | t.Errorf("%s", serializationTestError(err)) 70 | } 71 | 72 | sameFakeToken, err := serializeMacaroon(fakeMacaroon) 73 | if err != nil { 74 | t.Errorf("%s", serializationTestError(err)) 75 | } 76 | 77 | if sameFakeToken != fakeToken { 78 | t.Errorf("Token Serialization mismatch") 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /headeredevent_test.go: 
-------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | const TestHeaderedExampleEvent = `{"auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"content":{"name":"test3"},"depth":7,"event_id":"$yvN1b43rlmcOs5fY:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"origin":"localhost","origin_server_ts":1510854416361,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"room_id":"!19Mp0U9hjajeIiw1:localhost","sender":"@test:localhost","signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"state_key":"","type":"m.room.name","_room_version":"1","_event_id":"$yvN1b43rlmcOs5fY:localhost"}` 8 | 9 | func TestUnmarshalMarshalHeaderedEvent(t *testing.T) { 10 | output, err := NewEventFromHeaderedJSON([]byte(TestHeaderedExampleEvent), false) 11 | if err != nil { 12 | t.Fatal(err) 13 | } 14 | j, err := output.ToHeaderedJSON() 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | if string(j) != TestHeaderedExampleEvent { 19 | t.Logf("got: %s", string(j)) 20 | t.Logf("expected: %s", TestHeaderedExampleEvent) 21 | t.Fatalf("round-trip unmarshal and marshal produced different results") 22 | } 23 | } 24 | 25 | func TestUnmarshalHeaderedV4AndVerifyEventID(t *testing.T) { 26 | initialEventJSON := 
`{"_room_version":"4","_event_id":"$RrGxF28UrHLmoASHndYb9Jb_1SFww2ptmtur9INS438","auth_events":[],"prev_events":[],"type":"m.room.create","room_id":"!uXDCzlYgCTHtiWCkEx:jki.re","sender":"@erikj:jki.re","content":{"room_version":"5","predecessor":{"room_id":"!gdRMqOrTFdOCYHNwOo:half-shot.uk","event_id":"$LP7ROBc4b+cMc1UE9haIz8q5AK2AIW4eJ90FfKLvyZI"},"creator":"@erikj:jki.re"},"depth":1,"prev_state":[],"state_key":"","origin":"jki.re","origin_server_ts":1560284621137,"hashes":{"sha256":"IX6zuNiJpJPNf70BLleL3HSCpjKeq9Uhu7uUpyDjBmc"},"signatures":{"jki.re":{"ed25519:auto":"O4IyFfF2PPtGp5uaDm8t57dZbdh8vc8Q64LgCwvzYRVItAMI0uisfiAFaxkVT7MRpzh6N2QNN5NMRXZKmgPYDA"}},"unsigned":{"age":1321650}}` 27 | expectedEventID := "$RrGxF28UrHLmoASHndYb9Jb_1SFww2ptmtur9INS438" 28 | event, err := NewEventFromHeaderedJSON([]byte(initialEventJSON), false) 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | 33 | if event.EventID() != expectedEventID { 34 | t.Fatalf("event ID '%s' does not match expected '%s'", event.EventID(), expectedEventID) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /spec/servername.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import ( 4 | "net" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | // A ServerName is the name a matrix homeserver is identified by. 10 | // It is a DNS name or IP address optionally followed by a port. 11 | // 12 | // https://matrix.org/docs/spec/appendices.html#server-name 13 | type ServerName string 14 | 15 | // ParseAndValidateServerName splits a ServerName into a host and port part, 16 | // and checks that it is a valid server name according to the spec. 17 | // 18 | // if there is no explicit port, returns '-1' as the port. 19 | func ParseAndValidateServerName(serverName ServerName) (host string, port int, valid bool) { 20 | // Don't go any further if the server name is an empty string. 
21 | if len(serverName) == 0 { 22 | return 23 | } 24 | 25 | host, port = splitServerName(serverName) 26 | 27 | // the host part must be one of: 28 | // - a valid (ascii) dns name 29 | // - an IPv4 address 30 | // - an IPv6 address 31 | 32 | if len(host) == 0 { 33 | return 34 | } 35 | 36 | if host[0] == '[' { 37 | // must be a valid IPv6 address 38 | if host[len(host)-1] != ']' { 39 | return 40 | } 41 | ip := host[1 : len(host)-1] 42 | if net.ParseIP(ip) == nil { 43 | return 44 | } 45 | valid = true 46 | return 47 | } 48 | 49 | // try parsing as an IPv4 address 50 | ip := net.ParseIP(host) 51 | if ip != nil && ip.To4() != nil { 52 | valid = true 53 | return 54 | } 55 | 56 | // must be a valid DNS Name 57 | for _, r := range host { 58 | if !isDNSNameChar(r) { 59 | return 60 | } 61 | } 62 | 63 | valid = true 64 | return 65 | } 66 | 67 | func isDNSNameChar(r rune) bool { 68 | if r >= 'A' && r <= 'Z' { 69 | return true 70 | } 71 | if r >= 'a' && r <= 'z' { 72 | return true 73 | } 74 | if r >= '0' && r <= '9' { 75 | return true 76 | } 77 | if r == '-' || r == '.' { 78 | return true 79 | } 80 | return false 81 | } 82 | 83 | // splitServerName splits a ServerName into host and port, without doing 84 | // any validation. 
85 | // 86 | // if there is no explicit port, returns '-1' as the port 87 | func splitServerName(serverName ServerName) (string, int) { 88 | nameStr := string(serverName) 89 | 90 | lastColon := strings.LastIndex(nameStr, ":") 91 | if lastColon < 0 { 92 | // no colon: no port 93 | return nameStr, -1 94 | } 95 | 96 | portStr := nameStr[lastColon+1:] 97 | port, err := strconv.ParseUint(portStr, 10, 16) 98 | if err != nil { 99 | // invalid port (possibly an ipv6 host) 100 | return nameStr, -1 101 | } 102 | 103 | return nameStr[:lastColon], int(port) 104 | } 105 | -------------------------------------------------------------------------------- /hex_string_test.go: -------------------------------------------------------------------------------- 1 | /* Copyright 2016-2017 Vector Creations Ltd 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package gomatrixserverlib 17 | 18 | import ( 19 | "encoding/json" 20 | "testing" 21 | ) 22 | 23 | func TestMarshalHex(t *testing.T) { 24 | input := HexString("this\xffis\xffa\xfftest") 25 | want := `"74686973ff6973ff61ff74657374"` 26 | got, err := json.Marshal(input) 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | if string(got) != want { 31 | t.Fatalf("json.Marshal(HexString(%q)): wanted %q got %q", string(input), want, string(got)) 32 | } 33 | } 34 | 35 | func TestUnmarshalHex(t *testing.T) { 36 | input := []byte(`"74686973ff6973ff61ff74657374"`) 37 | want := "this\xffis\xffa\xfftest" 38 | var got HexString 39 | err := json.Unmarshal(input, &got) 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | if string(got) != want { 44 | t.Fatalf("json.Unmarshal(%q): wanted %q got %q", string(input), want, string(got)) 45 | } 46 | } 47 | 48 | func TestMarshalHexStruct(t *testing.T) { 49 | input := struct{ Value HexString }{HexString("this\xffis\xffa\xfftest")} 50 | want := `{"Value":"74686973ff6973ff61ff74657374"}` 51 | got, err := json.Marshal(input) 52 | if err != nil { 53 | t.Fatal(err) 54 | } 55 | if string(got) != want { 56 | t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got)) 57 | } 58 | } 59 | 60 | func TestMarshalHexMap(t *testing.T) { 61 | input := map[string]HexString{"Value": HexString("this\xffis\xffa\xfftest")} 62 | want := `{"Value":"74686973ff6973ff61ff74657374"}` 63 | got, err := json.Marshal(input) 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | if string(got) != want { 68 | t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got)) 69 | } 70 | } 71 | 72 | func TestMarshalHexSlice(t *testing.T) { 73 | input := []HexString{HexString("this\xffis\xffa\xfftest")} 74 | want := `["74686973ff6973ff61ff74657374"]` 75 | got, err := json.Marshal(input) 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | if string(got) != want { 80 | t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got)) 81 | } 
82 | } 83 | -------------------------------------------------------------------------------- /fclient/invitev2.go: -------------------------------------------------------------------------------- 1 | package fclient 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | 7 | "github.com/matrix-org/gomatrixserverlib" 8 | "github.com/tidwall/gjson" 9 | ) 10 | 11 | // InviteV2Request and InviteV2StrippedState are defined in 12 | // https://matrix.org/docs/spec/server_server/r0.1.3#put-matrix-federation-v2-invite-roomid-eventid 13 | 14 | func NewInviteV2Request(event gomatrixserverlib.PDU, state []gomatrixserverlib.InviteStrippedState) ( 15 | request InviteV2Request, err error, 16 | ) { 17 | if !gomatrixserverlib.KnownRoomVersion(event.Version()) { 18 | err = gomatrixserverlib.UnsupportedRoomVersionError{ 19 | Version: event.Version(), 20 | } 21 | return 22 | } 23 | request.fields.inviteV2RequestHeaders = inviteV2RequestHeaders{ 24 | RoomVersion: event.Version(), 25 | InviteRoomState: state, 26 | } 27 | request.fields.Event = event 28 | return 29 | } 30 | 31 | type inviteV2RequestHeaders struct { 32 | RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` 33 | InviteRoomState []gomatrixserverlib.InviteStrippedState `json:"invite_room_state"` 34 | } 35 | 36 | // InviteV2Request is used in the body of a /_matrix/federation/v2/invite request. 
37 | type InviteV2Request struct { 38 | fields struct { 39 | inviteV2RequestHeaders 40 | Event gomatrixserverlib.PDU `json:"event"` 41 | } 42 | } 43 | 44 | // MarshalJSON implements json.Marshaller 45 | func (i InviteV2Request) MarshalJSON() ([]byte, error) { 46 | return json.Marshal(i.fields) 47 | } 48 | 49 | // UnmarshalJSON implements json.Unmarshaller 50 | func (i *InviteV2Request) UnmarshalJSON(data []byte) error { 51 | err := json.Unmarshal(data, &i.fields.inviteV2RequestHeaders) 52 | if err != nil { 53 | return err 54 | } 55 | eventJSON := gjson.GetBytes(data, "event") 56 | if !eventJSON.Exists() { 57 | return errors.New("gomatrixserverlib: request doesn't contain event") 58 | } 59 | verImpl, err := gomatrixserverlib.GetRoomVersion(i.fields.RoomVersion) 60 | if err != nil { 61 | return err 62 | } 63 | i.fields.Event, err = verImpl.NewEventFromUntrustedJSON([]byte(eventJSON.String())) 64 | return err 65 | } 66 | 67 | // Event returns the invite event. 68 | func (i *InviteV2Request) Event() gomatrixserverlib.PDU { 69 | return i.fields.Event 70 | } 71 | 72 | // RoomVersion returns the room version of the invited room. 73 | func (i *InviteV2Request) RoomVersion() gomatrixserverlib.RoomVersion { 74 | return i.fields.RoomVersion 75 | } 76 | 77 | // InviteRoomState returns stripped state events for the room, containing 78 | // enough information for the client to identify the room. 79 | func (i *InviteV2Request) InviteRoomState() []gomatrixserverlib.InviteStrippedState { 80 | return i.fields.InviteRoomState 81 | } 82 | -------------------------------------------------------------------------------- /spec/roomid.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | "strings" 7 | ) 8 | 9 | const roomSigil = '!' 
10 | 11 | var domainlessRoomIDRegexp = regexp.MustCompile(`^[A-Za-z0-9_-]{43}$`) 12 | 13 | // A RoomID identifies a matrix room as per the matrix specification 14 | // https://spec.matrix.org/v1.6/appendices/#room-ids-and-event-ids 15 | type RoomID struct { 16 | raw string 17 | opaqueID string 18 | domain string 19 | isDomainless bool 20 | } 21 | 22 | func NewRoomID(id string) (*RoomID, error) { 23 | return parseAndValidateRoomID(id) 24 | } 25 | 26 | // Returns the full roomID string including leading sigil 27 | func (room RoomID) String() string { 28 | return room.raw 29 | } 30 | 31 | // Returns just the localpart of the roomID 32 | func (room RoomID) OpaqueID() string { 33 | return room.opaqueID 34 | } 35 | 36 | // Returns just the domain of the roomID 37 | func (room RoomID) Domain() ServerName { 38 | if room.isDomainless { 39 | panic("Called RoomID.Domain() on domain-less room ID " + room.String()) 40 | } 41 | return ServerName(room.domain) 42 | } 43 | 44 | func parseAndValidateRoomID(id string) (*RoomID, error) { 45 | idLength := len(id) 46 | if idLength < 4 { // 4 since minimum roomID includes an !, :, non-empty opaque ID, non-empty domain 47 | return nil, fmt.Errorf("length %d is too short to be valid", idLength) 48 | } 49 | if id[0] != roomSigil { 50 | return nil, fmt.Errorf("first character is not '%c'", roomSigil) 51 | } 52 | 53 | hasDomain := strings.ContainsRune(id, localDomainSeparator) 54 | if !hasDomain { 55 | // new form room IDs must be 43 characters of unpadded urlsafe base64, so check that now. 
56 | if !domainlessRoomIDRegexp.MatchString(id[1:]) { 57 | return nil, fmt.Errorf("domainless room IDs must consist of 43 unpadded urlsafe base64 characters") 58 | } 59 | return &RoomID{ 60 | raw: id, 61 | opaqueID: id[1:], 62 | domain: "", 63 | isDomainless: true, 64 | }, nil 65 | } 66 | 67 | opaqueID, domain, found := strings.Cut(id[1:], string(localDomainSeparator)) 68 | if !found { 69 | return nil, fmt.Errorf("at least one '%c' is expected in the room id", localDomainSeparator) 70 | } 71 | if _, _, ok := ParseAndValidateServerName(ServerName(domain)); !ok { 72 | return nil, fmt.Errorf("domain is invalid") 73 | } 74 | 75 | // NOTE: There are no character limitations on the opaque part of room ids 76 | opaqueLength := len(opaqueID) 77 | if opaqueLength < 1 { 78 | return nil, fmt.Errorf("opaque id length %d is too short to be valid", opaqueLength) 79 | } 80 | 81 | roomID := &RoomID{ 82 | raw: id, 83 | opaqueID: opaqueID, 84 | domain: domain, 85 | isDomainless: false, 86 | } 87 | return roomID, nil 88 | } 89 | -------------------------------------------------------------------------------- /spec/roomid_test.go: -------------------------------------------------------------------------------- 1 | package spec_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/matrix-org/gomatrixserverlib/spec" 8 | ) 9 | 10 | func TestValidRoomIDs(t *testing.T) { 11 | tests := map[string]struct { 12 | opaque string 13 | domain string 14 | }{ 15 | "basic": {opaque: defaultLocalpart, domain: defaultDomain}, 16 | "extensive_opaque": {opaque: "!\"#$%&'()*+,-./0123456789;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~", domain: defaultDomain}, 17 | "domain_with_port": {opaque: defaultLocalpart, domain: "domain.org:80"}, 18 | "minimum_id": {opaque: "a", domain: "1"}, 19 | } 20 | 21 | for name, tc := range tests { 22 | t.Run(name, func(t *testing.T) { 23 | raw := fmt.Sprintf("!%s:%s", tc.opaque, tc.domain) 24 | 25 | roomID, err := spec.NewRoomID(raw) 26 | 
27 | if err != nil { 28 | t.Fatalf("valid roomID should not fail: %s", err.Error()) 29 | } 30 | if roomID.OpaqueID() != tc.opaque { 31 | t.Fatalf("OpaqueID - Expected: %s Actual: %s ", tc.opaque, roomID.OpaqueID()) 32 | } 33 | if roomID.Domain() != spec.ServerName(tc.domain) { 34 | t.Fatalf("Domain - Expected: %s Actual: %s ", spec.ServerName(tc.domain), roomID.Domain()) 35 | } 36 | if roomID.String() != raw { 37 | t.Fatalf("Raw - Expected: %s Actual: %s ", raw, roomID.String()) 38 | } 39 | }) 40 | } 41 | } 42 | 43 | func TestInvalidRoomIDs(t *testing.T) { 44 | tests := map[string]struct { 45 | rawRoomID string 46 | }{ 47 | "empty": {rawRoomID: ""}, 48 | "no_leading_!": {rawRoomID: "localpart:domain"}, 49 | "no_colon": {rawRoomID: "!localpartdomain"}, 50 | "invalid_domain_chars": {rawRoomID: "!localpart:domain/"}, 51 | "no_local": {rawRoomID: "!:domain"}, 52 | "no_domain": {rawRoomID: "!localpart:"}, 53 | } 54 | 55 | for name, tc := range tests { 56 | t.Run(name, func(t *testing.T) { 57 | _, err := spec.NewRoomID(tc.rawRoomID) 58 | 59 | if err == nil { 60 | t.Fatalf("roomID is not valid, it shouldn't parse") 61 | } 62 | }) 63 | } 64 | } 65 | 66 | func TestSameRoomIDsAreEqual(t *testing.T) { 67 | id := "!localpart:domain" 68 | 69 | roomID, err := spec.NewRoomID(id) 70 | roomID2, err2 := spec.NewRoomID(id) 71 | 72 | if err != nil || err2 != nil { 73 | t.Fatalf("roomID is valid, it should parse") 74 | } 75 | 76 | if *roomID != *roomID2 { 77 | t.Fatalf("roomIDs should be equal") 78 | } 79 | } 80 | 81 | func TestDifferentRoomIDsAreNotEqual(t *testing.T) { 82 | id := "!localpart:domain" 83 | id2 := "!localpart2:domain" 84 | 85 | roomID, err := spec.NewRoomID(id) 86 | roomID2, err2 := spec.NewRoomID(id2) 87 | 88 | if err != nil || err2 != nil { 89 | t.Fatalf("roomID is valid, it should parse") 90 | } 91 | 92 | if *roomID == *roomID2 { 93 | t.Fatalf("roomIDs shouldn't be equal") 94 | } 95 | } 96 | 
--------------------------------------------------------------------------------
/fclient/invitev2_test.go:
--------------------------------------------------------------------------------
package fclient

import (
	"encoding/json"
	"testing"

	"github.com/matrix-org/gomatrixserverlib"
)

// TestInviteV2ExampleEvent is a headered room-version-1 m.room.name event
// (note the leading "_room_version" key) shared by the tests below.
const TestInviteV2ExampleEvent = `{"_room_version":"1","auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"content":{"name":"test3"},"depth":7,"event_id":"$yvN1b43rlmcOs5fY:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"origin":"localhost","origin_server_ts":1510854416361,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"room_id":"!19Mp0U9hjajeIiw1:localhost","sender":"@test:localhost","signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"state_key":"","type":"m.room.name"}`

// TestMarshalInviteV2Request checks the wire form of an InviteV2Request:
// room_version and invite_room_state come first, followed by the event,
// whose serialised form no longer carries the "_room_version" key.
func TestMarshalInviteV2Request(t *testing.T) {
	expected := `{"room_version":"1","invite_room_state":[],"event":{"auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"content":{"name":"test3"},"depth":7,"event_id":"$yvN1b43rlmcOs5fY:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"origin":"localhost","origin_server_ts":1510854416361,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"room_id":"!19Mp0U9hjajeIiw1:localhost","sender":"@test:localhost","signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"state_key":"","type":"m.room.name"}}`

	output, err := gomatrixserverlib.NewEventFromHeaderedJSON([]byte(TestInviteV2ExampleEvent), false)
	if err != nil {
		t.Fatal(err)
	}

	inviteReq, err := NewInviteV2Request(output, []gomatrixserverlib.InviteStrippedState{})
	if err != nil {
		t.Fatal(err)
	}

	j, err := json.Marshal(inviteReq)
	if err != nil {
		t.Fatal(err)
	}

	if string(j) != expected {
		t.Fatalf("got %q, expected %q", string(j), expected)
	}
}

// TestStrippedState checks that NewInviteStrippedState keeps only the
// content, state_key, type and sender of the event, marshalled in that order.
func TestStrippedState(t *testing.T) {
	expected := `{"content":{"name":"test3"},"state_key":"","type":"m.room.name","sender":"@test:localhost"}`

	output, err := gomatrixserverlib.NewEventFromHeaderedJSON([]byte(TestInviteV2ExampleEvent), false)
	if err != nil {
		t.Fatal(err)
	}

	stripped := gomatrixserverlib.NewInviteStrippedState(output)

	j, err := json.Marshal(stripped)
	if err != nil {
		t.Fatal(err)
	}

	if string(j) != expected {
		t.Fatalf("got %q, expected %q", string(j), expected)
	}
}

-------------------------------------------------------------------------------- /fclient/dnscache_test.go: -------------------------------------------------------------------------------- 1 | package fclient 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | var dnsResolverHits chan string 11 | 12 | func init() { 13 | dnsResolverHits = make(chan string, 1) 14 | } 15 | 16 | type dummyNetResolver struct{} 17 | 18 | func (r *dummyNetResolver) LookupIPAddr(_ context.Context, hostname string) ([]net.IPAddr, error) { 19 | dnsResolverHits <- hostname 20 | return []net.IPAddr{ 21 | { 22 | IP: net.IP("1.2.3.4"), 23 | }, 24 | }, nil 25 | } 26 | 27 | func mustCreateCache(size int, lifetime time.Duration) *DNSCache { 28 | cache := NewDNSCache(size, lifetime, []string{}, []string{}) 29 | cache.resolver = &dummyNetResolver{} 30 | return cache 31 | } 32 | 33 | func TestDNSCache(t *testing.T) { 34 | cache := mustCreateCache(1, time.Second) 35 | ctx := context.Background() 36 | 37 | // STEP 1: First we'll start with first.com. 38 | 39 | // first.com shouldn't be in the cache at this point. 40 | if _, ok := cache.lookup(ctx, "first.com"); ok { 41 | t.Fatalf("shouldn't be in the cache") 42 | } 43 | select { 44 | case hostname := <-dnsResolverHits: 45 | if hostname != "first.com" { 46 | t.Fatalf("expected resolve for first.com, got %q", hostname) 47 | } 48 | default: 49 | t.Fatalf("should have hit the resolver") 50 | } 51 | 52 | // first.com should be in the cache this time. 53 | if _, ok := cache.lookup(ctx, "first.com"); !ok { 54 | t.Fatalf("should be in the cache") 55 | } 56 | select { 57 | case hostname := <-dnsResolverHits: 58 | t.Fatalf("shouldn't have hit the resolver but got a resolve for %q", hostname) 59 | default: 60 | } 61 | 62 | // STEP 2: Then we'll try second.net. Since the cache is only 63 | // one entry big, this should evict first.com. 64 | 65 | // second.net shouldn't be in the cache at this point. 
66 | if _, ok := cache.lookup(ctx, "second.net"); ok { 67 | t.Fatalf("shouldn't be in the cache") 68 | } 69 | select { 70 | case hostname := <-dnsResolverHits: 71 | if hostname != "second.net" { 72 | t.Fatalf("expected resolve for second.net, got %q", hostname) 73 | } 74 | default: 75 | t.Fatalf("should have hit the resolver") 76 | } 77 | 78 | // second.net should be in the cache this time. 79 | if _, ok := cache.lookup(ctx, "second.net"); !ok { 80 | t.Fatalf("should be in the cache") 81 | } 82 | select { 83 | case hostname := <-dnsResolverHits: 84 | t.Fatalf("shouldn't have hit the resolver but got a resolve for %q", hostname) 85 | default: 86 | } 87 | 88 | // STEP 3: Now we'll retry first.com, which should have been 89 | // evicted. 90 | 91 | // first.com shouldn't be in the cache at this point. 92 | if _, ok := cache.lookup(ctx, "first.com"); ok { 93 | t.Fatalf("shouldn't be in the cache") 94 | } 95 | select { 96 | case hostname := <-dnsResolverHits: 97 | if hostname != "first.com" { 98 | t.Fatalf("expected resolve for first.com, got %q", hostname) 99 | } 100 | default: 101 | t.Fatalf("should have hit the resolver") 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /spec/senderid_test.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import ( 4 | "crypto/ed25519" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestUserIDSenderIDs(t *testing.T) { 10 | tests := map[string]UserID{ 11 | "basic": NewUserIDOrPanic("@localpart:domain", false), 12 | "extensive_local": NewUserIDOrPanic("@abcdefghijklmnopqrstuvwxyz0123456789._=-/:domain", false), 13 | "extensive_local_historic": NewUserIDOrPanic("@!\"#$%&'()*+,-./0123456789;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~:domain", true), 14 | "domain_with_port": NewUserIDOrPanic("@localpart:domain.org:80", false), 15 | "minimum_id": NewUserIDOrPanic("@a:1", false), 16 | } 17 | 18 | for name, 
userID := range tests { 19 | t.Run(name, func(t *testing.T) { 20 | senderID := SenderIDFromUserID(userID) 21 | 22 | if string(senderID) != userID.String() { 23 | t.Fatalf("Created sender ID did not match user ID string: senderID %s for user ID %s", string(senderID), userID.String()) 24 | } 25 | if !senderID.IsUserID() { 26 | t.Fatalf("IsUserID returned false for user ID: %s", userID.String()) 27 | } 28 | if senderID.IsPseudoID() { 29 | t.Fatalf("IsPseudoID returned true for user ID: %s", userID.String()) 30 | } 31 | returnedUserID := senderID.ToUserID() 32 | if returnedUserID == nil { 33 | t.Fatalf("ToUserID returned nil value") 34 | } 35 | if !reflect.DeepEqual(userID, *returnedUserID) { 36 | t.Fatalf("ToUserID returned different user ID than one used to created sender ID\ncreated with %s\nreturned %s", userID, *returnedUserID) 37 | } 38 | roomKey := senderID.ToPseudoID() 39 | if roomKey != nil { 40 | t.Fatalf("ToPseudoID returned non-nil value for user ID: %s, returned %s", userID.String(), roomKey) 41 | } 42 | }) 43 | } 44 | } 45 | 46 | func TestPseudoIDSenderIDs(t *testing.T) { 47 | // Generate key from all zeroes seed 48 | testKeySeed := make([]byte, 32) 49 | testKey := ed25519.NewKeyFromSeed(testKeySeed) 50 | 51 | t.Run("test pseudo ID", func(t *testing.T) { 52 | senderID := SenderIDFromPseudoIDKey(testKey) 53 | testPubkey := testKey.Public() 54 | expectedSenderIDString := Base64Bytes(testPubkey.(ed25519.PublicKey)).Encode() 55 | 56 | if string(senderID) != expectedSenderIDString { 57 | t.Fatalf("Created sender ID did not match provided key: created sender ID %s, expected: %s", string(senderID), expectedSenderIDString) 58 | } 59 | if !senderID.IsPseudoID() { 60 | t.Fatalf("IsPseudoID returned false for pseudo ID sender ID") 61 | } 62 | if senderID.IsUserID() { 63 | t.Fatalf("IsUserID returned true for pseudo ID sender ID") 64 | } 65 | returnedKey := senderID.ToPseudoID() 66 | if returnedKey == nil { 67 | t.Fatal("ToPseudoID returned nil") 68 | } 69 | if 
!reflect.DeepEqual(testPubkey, *returnedKey) { 70 | t.Fatalf("ToPseudoID returned different key to the one used to create the sender ID:\ncreated with %v\nreturned %v", testPubkey, *returnedKey) 71 | } 72 | userID := senderID.ToUserID() 73 | if userID != nil { 74 | t.Fatalf("ToUserID returned non-nil value %v", userID.String()) 75 | } 76 | }) 77 | } 78 | -------------------------------------------------------------------------------- /tokens/tokens_handlers_test.go: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 
12 | 13 | package tokens 14 | 15 | import ( 16 | "testing" 17 | ) 18 | 19 | var ( 20 | // If any of these options are missing, validation should fail 21 | invalidMissings = []string{"ServerPrivateKey", "UserID"} 22 | invalidKeyTokenOp = TokenOptions{ 23 | ServerPrivateKey: []byte("notASecretKey"), 24 | UserID: "aRandomUserID", 25 | } 26 | invalidUserTokenOp = TokenOptions{ 27 | ServerPrivateKey: []byte("aSecretKey"), 28 | UserID: "notTheSameUserID", 29 | } 30 | ) 31 | 32 | func expiredValidTokenOp() TokenOptions { 33 | op := validTokenOp 34 | // This will set the expiry to 1 second ago 35 | op.Duration = -1 36 | return op 37 | } 38 | 39 | func TestExpiredLoginToken(t *testing.T) { 40 | fakeToken, err := GenerateLoginToken(expiredValidTokenOp()) 41 | if err != nil { 42 | t.Errorf("Unexpected error from token generation: %v", err) 43 | } 44 | if err = ValidateToken(validTokenOp, fakeToken); err == nil { 45 | t.Error("Token validation should fail for expired token") 46 | } 47 | } 48 | 49 | func TestValidateToken(t *testing.T) { 50 | fakeToken, err := GenerateLoginToken(validTokenOp) 51 | if err != nil { 52 | t.Errorf("Token generation failed for valid TokenOptions with err: %s", err.Error()) 53 | } 54 | 55 | // Test validation 56 | res := ValidateToken(validTokenOp, fakeToken) 57 | if res != nil { 58 | t.Error("Token validation failed with response: ", res) 59 | } 60 | 61 | // Test validation fails for invalid TokenOp 62 | for _, invalidMissing := range invalidMissings { 63 | res = ValidateToken(invalidTokenOps[invalidMissing], fakeToken) 64 | if res == nil { 65 | t.Errorf("Token validation should fail for TokenOptions with missing %s", invalidMissing) 66 | } 67 | } 68 | 69 | for _, invalid := range []TokenOptions{invalidKeyTokenOp, invalidUserTokenOp} { 70 | res = ValidateToken(invalid, fakeToken) 71 | if res == nil { 72 | t.Error("Token validation should fail for invalid TokenOptions: ", invalid) 73 | } 74 | } 75 | } 76 | 77 | func TestGetUserFromToken(t *testing.T) 
{ 78 | fakeToken, err := GenerateLoginToken(validTokenOp) 79 | if err != nil { 80 | t.Errorf("Token generation failed for valid TokenOptions with err: %s", err.Error()) 81 | } 82 | 83 | // Test validation 84 | name, err := GetUserFromToken(fakeToken) 85 | if err != nil { 86 | t.Error("Failed to get userID from Token: ", err) 87 | } 88 | 89 | if name != validTokenOp.UserID { 90 | t.Error("UserID from Token doesn't match, got: ", name, " expected: ", validTokenOp.UserID) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /authchain.go: -------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/matrix-org/gomatrixserverlib/spec" 8 | ) 9 | 10 | // EventProvider returns the requested list of events. 11 | type EventProvider func(roomVer RoomVersion, eventIDs []string) ([]PDU, error) 12 | 13 | // VerifyEventAuthChain will verify that the event is allowed according to its auth_events, and then 14 | // recursively verify each of those auth_events. 15 | // 16 | // This function implements Step 4 of https://matrix.org/docs/spec/server_server/latest#checks-performed-on-receipt-of-a-pdu 17 | // "Passes authorization rules based on the event's auth events, otherwise it is rejected." 18 | // If an event passes this function without error, the caller should make sure that all the auth_events were actually for 19 | // a valid room state, and not referencing random bits of room state from different positions in time (Step 5). 20 | // 21 | // The `provideEvents` function will only be called for *new* events rather than for everything as it is 22 | // assumed that this function is costly. Failing to provide all the requested events will fail this function. 23 | // Returning an error from `provideEvents` will also fail this function. 
func VerifyEventAuthChain(ctx context.Context, eventToVerify PDU, provideEvents EventProvider, userIDForSender spec.UserIDForSender) error {
	// NOTE: ctx is not used in this function body.
	eventsByID := make(map[string]PDU) // A lookup table for verifying this auth chain
	evv := eventToVerify
	eventsByID[evv.EventID()] = evv
	verifiedEvents := make(map[string]bool) // events are put here when they are fully verified.
	eventsToVerify := []PDU{evv}
	var curr PDU

	for len(eventsToVerify) > 0 {
		// pop the top of the stack
		// A stack works best here as it means we do depth-first verification which reduces the
		// number of duplicate events to verify.
		curr, eventsToVerify = eventsToVerify[len(eventsToVerify)-1], eventsToVerify[:len(eventsToVerify)-1]
		if verifiedEvents[curr.EventID()] {
			continue // already verified
		}
		// work out which events we need to fetch, if any.
		// Anything already in eventsByID is not requested again, which is what
		// limits provideEvents to *new* events.
		var need []string
		for _, needEventID := range curr.AuthEventIDs() {
			if eventsByID[needEventID] == nil {
				need = append(need, needEventID)
			}
		}
		// fetch the events and add them to the lookup table
		if len(need) > 0 {
			// Every fetch uses the room version of the event being verified,
			// not the version of curr.
			newEvents, err := provideEvents(eventToVerify.Version(), need)
			if err != nil {
				return fmt.Errorf("gomatrixserverlib: VerifyEventAuthChain failed to obtain auth events: %w", err)
			}
			for i := range newEvents {
				eventsByID[newEvents[i].EventID()] = newEvents[i] // add to lookup table
			}
			eventsToVerify = append(eventsToVerify, newEvents...) // verify these events too
		}
		// verify the event
		// curr is checked now against the lookup table; the auth events just
		// pushed are checked on later iterations, so nil is only returned once
		// every event reached has passed.
		// NOTE(review): provideEvents is passed through here as well, presumably
		// so the check can fetch anything still missing — confirm in
		// checkAllowedByAuthEvents.
		if err := checkAllowedByAuthEvents(curr, eventsByID, provideEvents, userIDForSender); err != nil {
			return fmt.Errorf("gomatrixserverlib: VerifyEventAuthChain %v failed auth check: %w", curr.EventID(), err)
		}
		// add to the verified list
		verifiedEvents[curr.EventID()] = true
	}
	return nil
}

--------------------------------------------------------------------------------
/spec/senderid.go:
--------------------------------------------------------------------------------
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package spec

import (
	"context"

	"golang.org/x/crypto/ed25519"
)

// SenderID identifies the sender of an event. It holds either a user ID
// string (see SenderIDFromUserID) or a base64-encoded pseudo ID public key
// (see SenderIDFromPseudoIDKey).
type SenderID string

// UserIDForSender maps a sender ID in the given room to a user ID; how the
// mapping is resolved is up to the implementation.
type UserIDForSender func(roomID RoomID, senderID SenderID) (*UserID, error)

// SenderIDForUser maps a user ID in the given room to the sender ID used
// there; how the mapping is resolved is up to the implementation.
type SenderIDForUser func(roomID RoomID, userID UserID) (*SenderID, error)

// CreateSenderID is a function used to create the pseudoID private key.
type CreateSenderID func(ctx context.Context, userID UserID, roomID RoomID, roomVersion string) (SenderID, ed25519.PrivateKey, error)

// StoreSenderIDFromPublicID is a function to store the mxid_mapping after receiving a join event over federation.
32 | type StoreSenderIDFromPublicID func(ctx context.Context, senderID SenderID, userID string, id RoomID) error 33 | 34 | // Create a new sender ID from a private room key 35 | func SenderIDFromPseudoIDKey(key ed25519.PrivateKey) SenderID { 36 | return SenderID(Base64Bytes(key.Public().(ed25519.PublicKey)).Encode()) 37 | } 38 | 39 | // Create a new sender ID from a user ID 40 | func SenderIDFromUserID(user UserID) SenderID { 41 | return SenderID(user.String()) 42 | } 43 | 44 | // Decodes this sender ID as base64, i.e. returns the raw bytes of the 45 | // pseudo ID used to create this SenderID, assuming this SenderID was made 46 | // using a pseudo ID. 47 | func (s SenderID) RawBytes() (res Base64Bytes, err error) { 48 | err = res.Decode(string(s)) 49 | if err != nil { 50 | return nil, err 51 | } 52 | return res, nil 53 | } 54 | 55 | // Returns true if this SenderID was made using a user ID 56 | func (s SenderID) IsUserID() bool { 57 | // Key is base64, @ is not a valid base64 char 58 | // So if string starts with @, then this sender ID must 59 | // be a user ID 60 | return string(s)[0] == '@' 61 | } 62 | 63 | // Returns true if this SenderID was made using a pseudo ID 64 | func (s SenderID) IsPseudoID() bool { 65 | return !s.IsUserID() 66 | } 67 | 68 | // Returns the non-nil UserID used to create this SenderID, or nil 69 | // if this SenderID was not created using a UserID 70 | func (s SenderID) ToUserID() *UserID { 71 | if s.IsUserID() { 72 | uID, _ := NewUserID(string(s), true) 73 | return uID 74 | } 75 | 76 | return nil 77 | } 78 | 79 | // Returns the non-nil room public key (pseudo ID) used to create this 80 | // SenderID, or nil if this SenderID was not created using a pseudo ID 81 | func (s SenderID) ToPseudoID() *ed25519.PublicKey { 82 | if s.IsPseudoID() { 83 | decoded, err := s.RawBytes() 84 | if err != nil { 85 | return nil 86 | } 87 | key := ed25519.PublicKey([]byte(decoded)) 88 | return &key 89 | } 90 | 91 | return nil 92 | } 93 | 
-------------------------------------------------------------------------------- /tokens/tokens_handlers.go: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | package tokens 14 | 15 | import ( 16 | "errors" 17 | "strconv" 18 | "strings" 19 | "time" 20 | ) 21 | 22 | // GetUserFromToken returns the user associated with the token 23 | // Returns the error if something goes wrong. 24 | // Warning: Does not validate the token. Use ValidateToken for that. 25 | func GetUserFromToken(token string) (user string, err error) { 26 | mac, err := deSerializeMacaroon(token) 27 | if err != nil { 28 | return 29 | } 30 | 31 | user = string(mac.Id()[:]) 32 | return 33 | } 34 | 35 | // ValidateToken validates that the token is parseable and signed by this server. 36 | // Returns an error if the token is invalid, otherwise nil. 
37 | func ValidateToken(op TokenOptions, token string) error { 38 | mac, err := deSerializeMacaroon(token) 39 | if err != nil { 40 | return errors.New("Token does not represent a valid macaroon") 41 | } 42 | 43 | caveats, err := mac.VerifySignature(op.ServerPrivateKey, nil) 44 | if err != nil { 45 | return errors.New("Provided token was not issued by this server") 46 | } 47 | 48 | err = verifyCaveats(caveats, op.UserID) 49 | if err != nil { 50 | return errors.New("Provided token not authorized") 51 | } 52 | return nil 53 | } 54 | 55 | // verifyCaveats verifies caveats associated with a login token macaroon. 56 | // which are "gen = 1", "user_id = ...", "time < ..." 57 | // Returns nil on successful verification, else returns an error. 58 | func verifyCaveats(caveats []string, userID string) error { 59 | // variable verified represents a bitmap 60 | // last 4 bits are Uvvv where, 61 | // U: unknownCaveat 62 | // v: caveat to be verified 63 | var verified uint8 64 | now := time.Now().Second() 65 | 66 | LoopCaveat: 67 | for _, caveat := range caveats { 68 | switch { 69 | case caveat == Gen: 70 | verified |= 1 71 | case strings.HasPrefix(caveat, UserPrefix): 72 | if caveat[len(UserPrefix):] == userID { 73 | verified |= 2 74 | } 75 | case strings.HasPrefix(caveat, TimePrefix): 76 | if verifyExpiry(caveat[len(TimePrefix):], now) { 77 | verified |= 4 78 | } 79 | default: 80 | verified |= 8 81 | break LoopCaveat 82 | } 83 | } 84 | // Check that all three caveats are verified and no extra caveats 85 | // i.e. 
Uvvv == 0111 86 | if verified == 7 { 87 | return nil 88 | } else if verified >= 8 { 89 | return errors.New("Unknown caveat present") 90 | } 91 | 92 | return errors.New("Required caveats not present") 93 | } 94 | 95 | func verifyExpiry(t string, now int) bool { 96 | expiry, err := strconv.Atoi(t) 97 | 98 | if err != nil { 99 | return false 100 | } 101 | return now < expiry 102 | } 103 | -------------------------------------------------------------------------------- /fclient/invitev3_test.go: -------------------------------------------------------------------------------- 1 | package fclient 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/matrix-org/gomatrixserverlib" 8 | "github.com/matrix-org/gomatrixserverlib/spec" 9 | ) 10 | 11 | func TestMarshalInviteV3Request(t *testing.T) { 12 | expected := `{"room_version":"org.matrix.msc4014","invite_room_state":[],"event":{"sender":"@test:localhost","room_id":"!19Mp0U9hjajeIiw1:localhost","type":"m.room.name","state_key":"","prev_events":["upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"],"auth_events":["abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY","X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko","k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"],"depth":7,"signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"content":{"name":"test3"}}}` 13 | 14 | senderID := "@test:localhost" 15 | roomID := "!19Mp0U9hjajeIiw1:localhost" 16 | eventType := "m.room.name" 17 | stateKey := "" 18 | prevEvents := []string{"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"} 19 | authEvents := []string{"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY", "X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko", "k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"} 20 | depth := int64(7) 21 | signatures := spec.RawJSON(`{"localhost": {"ed25519:u9kP": "5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}}`) 22 | content := 
spec.RawJSON(`{"name":"test3"}`) 23 | 24 | output := gomatrixserverlib.ProtoEvent{ 25 | SenderID: senderID, 26 | RoomID: roomID, 27 | Type: eventType, 28 | StateKey: &stateKey, 29 | PrevEvents: prevEvents, 30 | AuthEvents: authEvents, 31 | Depth: depth, 32 | Signature: signatures, 33 | Content: content, 34 | } 35 | 36 | inviteReq, err := NewInviteV3Request(output, gomatrixserverlib.RoomVersionPseudoIDs, []gomatrixserverlib.InviteStrippedState{}) 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | 41 | j, err := json.Marshal(inviteReq) 42 | if err != nil { 43 | t.Fatal(err) 44 | } 45 | 46 | if string(j) != expected { 47 | t.Fatalf("\nresult: %q\nwanted: %q", string(j), expected) 48 | } 49 | 50 | var newRequest InviteV3Request 51 | err = json.Unmarshal(j, &newRequest) 52 | if err != nil { 53 | t.Fatal(err) 54 | } 55 | 56 | if newRequest.RoomVersion() != gomatrixserverlib.RoomVersionPseudoIDs { 57 | t.Fatalf("unmatched room version. expected: %v, got: %v", gomatrixserverlib.RoomVersionPseudoIDs, newRequest.RoomVersion()) 58 | } 59 | if len(newRequest.InviteRoomState()) != 0 { 60 | t.Fatalf("invite room state should not have any events") 61 | } 62 | if newRequest.Event().SenderID != senderID { 63 | t.Fatalf("unmatched senderID. expected: %v, got: %v", newRequest.Event().SenderID, senderID) 64 | 65 | } 66 | if newRequest.Event().RoomID != roomID { 67 | t.Fatalf("unmatched roomID. expected: %v, got: %v", newRequest.Event().RoomID, roomID) 68 | } 69 | if newRequest.Event().Type != eventType { 70 | t.Fatalf("unmatched type. expected: %v, got: %v", newRequest.Event().Type, eventType) 71 | 72 | } 73 | if *newRequest.Event().StateKey != stateKey { 74 | t.Fatalf("unmatched state key. expected: %v, got: %v", *newRequest.Event().StateKey, stateKey) 75 | } 76 | if newRequest.Event().Depth != depth { 77 | t.Fatalf("unmatched depth. 
expected: %v, got: %v", newRequest.Event().Depth, depth) 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /spec/userid.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | "strings" 7 | ) 8 | 9 | const userSigil = '@' 10 | const localDomainSeparator = ':' 11 | 12 | var validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) 13 | 14 | // A UserID identifies a matrix user as per the matrix specification 15 | type UserID struct { 16 | raw string 17 | local string 18 | domain string 19 | } 20 | 21 | // Creates a new UserID, returning an error if invalid 22 | func NewUserID(id string, allowHistoricalIDs bool) (*UserID, error) { 23 | return parseAndValidateUserID(id, allowHistoricalIDs) 24 | } 25 | 26 | // Creates a new UserID, panicing if invalid 27 | func NewUserIDOrPanic(id string, allowHistoricalIDs bool) UserID { 28 | userID, err := parseAndValidateUserID(id, allowHistoricalIDs) 29 | if err != nil { 30 | panic(fmt.Sprintf("NewUserIDOrPanic failed: invalid user ID %s: %s", id, err.Error())) 31 | } 32 | return *userID 33 | } 34 | 35 | // Returns the full userID string including leading sigil 36 | func (user *UserID) String() string { 37 | return user.raw 38 | } 39 | 40 | // Returns just the localpart of the userID 41 | func (user *UserID) Local() string { 42 | return user.local 43 | } 44 | 45 | // Returns just the domain of the userID 46 | func (user *UserID) Domain() ServerName { 47 | return ServerName(user.domain) 48 | } 49 | 50 | func parseAndValidateUserID(id string, allowHistoricalIDs bool) (*UserID, error) { 51 | idLength := len(id) 52 | if idLength < 4 || idLength > 255 { // 4 since minimum userID includes an @, :, non-empty localpart, non-empty domain 53 | return nil, fmt.Errorf("length %d is not within the bounds 4-255", idLength) 54 | } 55 | if id[0] != userSigil { 56 | return nil, fmt.Errorf("first character is 
not '%c'", userSigil) 57 | } 58 | 59 | localpart, domain, found := strings.Cut(id[1:], string(localDomainSeparator)) 60 | if !found { 61 | return nil, fmt.Errorf("at least one '%c' is expected in the user id", localDomainSeparator) 62 | } 63 | if _, _, ok := ParseAndValidateServerName(ServerName(domain)); !ok { 64 | return nil, fmt.Errorf("domain is invalid") 65 | } 66 | 67 | if allowHistoricalIDs { 68 | // NOTE: Allowed historical userIDs: 69 | // https://spec.matrix.org/v1.4/appendices/#historical-user-ids 70 | if !historicallyValidCharacters(localpart) { 71 | return nil, fmt.Errorf("local part contains invalid characters from historical set") 72 | } 73 | } else { 74 | // NOTE: Allowed in the latest spec: 75 | // https://spec.matrix.org/v1.4/appendices/#user-identifiers 76 | if !validUsernameRegex.MatchString(localpart) { 77 | return nil, fmt.Errorf("local part contains invalid characters") 78 | } 79 | } 80 | 81 | userID := &UserID{ 82 | raw: id, 83 | local: localpart, 84 | domain: domain, 85 | } 86 | return userID, nil 87 | } 88 | 89 | func historicallyValidCharacters(localpart string) bool { 90 | // This check is currently not safe because Synapse has historically 91 | // not enforced these character ranges properly, so there are many 92 | // user IDs out in the wild that fall outside this (like with emoji). 93 | // TODO: This function needs to be room version aware, as this will be 94 | // fixed in a future room version. 
95 | /* 96 | for _, r := range localpart { 97 | if r < 0x21 || r == 0x3A || r > 0x7E { 98 | return false 99 | } 100 | } 101 | */ 102 | 103 | return true 104 | } 105 | -------------------------------------------------------------------------------- /join.go: -------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | 7 | "github.com/matrix-org/gomatrixserverlib/spec" 8 | ) 9 | 10 | type FederatedJoinClient interface { 11 | MakeJoin(ctx context.Context, origin, s spec.ServerName, roomID, userID string) (res MakeJoinResponse, err error) 12 | SendJoin(ctx context.Context, origin, s spec.ServerName, event PDU) (res SendJoinResponse, err error) 13 | } 14 | 15 | type RestrictedRoomJoinInfo struct { 16 | LocalServerInRoom bool 17 | UserJoinedToRoom bool 18 | JoinedUsers []PDU 19 | } 20 | 21 | type MembershipQuerier interface { 22 | CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) 23 | } 24 | 25 | // RestrictedRoomJoinQuerier provides the information needed when processing a restricted room join request. 26 | type RestrictedRoomJoinQuerier interface { 27 | CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (PDU, error) 28 | InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) 29 | RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*RestrictedRoomJoinInfo, error) 30 | } 31 | 32 | type ProtoEvent struct { 33 | // The sender ID of the user sending the event. 34 | SenderID string `json:"sender"` 35 | // The room ID of the room this event is in. 36 | RoomID string `json:"room_id"` 37 | // The type of the event. 38 | Type string `json:"type"` 39 | // The state_key of the event if the event is a state event or nil if the event is not a state event. 
40 | StateKey *string `json:"state_key,omitempty"` 41 | // The events that immediately preceded this event in the room history. This can be 42 | // either []eventReference for room v1/v2, and []string for room v3 onwards. 43 | PrevEvents interface{} `json:"prev_events"` 44 | // The events needed to authenticate this event. This can be 45 | // either []eventReference for room v1/v2, and []string for room v3 onwards. 46 | AuthEvents interface{} `json:"auth_events"` 47 | // The event ID of the event being redacted if this event is a "m.room.redaction". 48 | Redacts string `json:"redacts,omitempty"` 49 | // The depth of the event, This should be one greater than the maximum depth of the previous events. 50 | // The create event has a depth of 1. 51 | Depth int64 `json:"depth"` 52 | // The JSON object for "signatures" key of the event. 53 | Signature spec.RawJSON `json:"signatures,omitempty"` 54 | // The JSON object for "content" key of the event. 55 | Content spec.RawJSON `json:"content"` 56 | // The JSON object for the "unsigned" key 57 | Unsigned spec.RawJSON `json:"unsigned,omitempty"` 58 | 59 | Version IRoomVersion `json:"-"` // exclude this field 60 | } 61 | 62 | func (pe *ProtoEvent) SetContent(content interface{}) (err error) { 63 | pe.Content, err = json.Marshal(content) 64 | return 65 | } 66 | 67 | // SetUnsigned sets the JSON unsigned key of the event. 
68 | func (pe *ProtoEvent) SetUnsigned(unsigned interface{}) (err error) { 69 | pe.Unsigned, err = json.Marshal(unsigned) 70 | return 71 | } 72 | 73 | type MakeJoinResponse interface { 74 | GetJoinEvent() ProtoEvent 75 | GetRoomVersion() RoomVersion 76 | } 77 | 78 | type SendJoinResponse interface { 79 | GetAuthEvents() EventJSONs 80 | GetStateEvents() EventJSONs 81 | GetOrigin() spec.ServerName 82 | GetJoinEvent() spec.RawJSON 83 | GetMembersOmitted() bool 84 | GetServersInRoom() []string 85 | } 86 | -------------------------------------------------------------------------------- /eventversion_test.go: -------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestEventIDForRoomVersionV1(t *testing.T) { 8 | initialEventJSON := `{"auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"content":{"name":"test3"},"depth":7,"event_id":"$yvN1b43rlmcOs5fY:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"origin":"localhost","origin_server_ts":1510854416361,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"room_id":"!19Mp0U9hjajeIiw1:localhost","sender":"@test:localhost","signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"state_key":"","type":"m.room.name"}` 9 | expectedEventID := "$yvN1b43rlmcOs5fY:localhost" 10 | 11 | event, err := MustGetRoomVersion(RoomVersionV1).NewEventFromTrustedJSON([]byte(initialEventJSON), false) 12 | if err != nil { 13 | t.Error(err) 14 | } 15 | 16 | if event.EventID() != expectedEventID { 17 | t.Fatalf("event ID '%s' does not match expected 
'%s'", event.EventID(), expectedEventID) 18 | } 19 | } 20 | 21 | func TestEventIDForRoomVersionV3(t *testing.T) { 22 | initialEventJSON := `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!uXDCzlYgCTHtiWCkEx:jki.re", "sender": "@erikj:jki.re", "content": {"room_version": "5", "predecessor": {"room_id": "!gdRMqOrTFdOCYHNwOo:half-shot.uk", "event_id": "$LP7ROBc4b+cMc1UE9haIz8q5AK2AIW4eJ90FfKLvyZI"}, "creator": "@erikj:jki.re"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "jki.re", "origin_server_ts": 1560284621137, "hashes": {"sha256": "IX6zuNiJpJPNf70BLleL3HSCpjKeq9Uhu7uUpyDjBmc"}, "signatures": {"jki.re": {"ed25519:auto": "O4IyFfF2PPtGp5uaDm8t57dZbdh8vc8Q64LgCwvzYRVItAMI0uisfiAFaxkVT7MRpzh6N2QNN5NMRXZKmgPYDA"}}, "unsigned": {"age": 1321650}}` 23 | expectedEventID := "$RrGxF28UrHLmoASHndYb9Jb/1SFww2ptmtur9INS438" 24 | 25 | event, err := MustGetRoomVersion(RoomVersionV3).NewEventFromTrustedJSON([]byte(initialEventJSON), false) 26 | if err != nil { 27 | t.Error(err) 28 | } 29 | 30 | if event.EventID() != expectedEventID { 31 | t.Fatalf("event ID '%s' does not match expected '%s'", event.EventID(), expectedEventID) 32 | } 33 | } 34 | 35 | func TestEventIDForRoomVersionV4(t *testing.T) { 36 | initialEventJSON := `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!uXDCzlYgCTHtiWCkEx:jki.re", "sender": "@erikj:jki.re", "content": {"room_version": "5", "predecessor": {"room_id": "!gdRMqOrTFdOCYHNwOo:half-shot.uk", "event_id": "$LP7ROBc4b+cMc1UE9haIz8q5AK2AIW4eJ90FfKLvyZI"}, "creator": "@erikj:jki.re"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "jki.re", "origin_server_ts": 1560284621137, "hashes": {"sha256": "IX6zuNiJpJPNf70BLleL3HSCpjKeq9Uhu7uUpyDjBmc"}, "signatures": {"jki.re": {"ed25519:auto": "O4IyFfF2PPtGp5uaDm8t57dZbdh8vc8Q64LgCwvzYRVItAMI0uisfiAFaxkVT7MRpzh6N2QNN5NMRXZKmgPYDA"}}, "unsigned": {"age": 1321650}}` 37 | expectedEventID := "$RrGxF28UrHLmoASHndYb9Jb_1SFww2ptmtur9INS438" 38 | 
39 | event, err := MustGetRoomVersion(RoomVersionV4).NewEventFromTrustedJSON([]byte(initialEventJSON), false) 40 | if err != nil { 41 | t.Error(err) 42 | } 43 | 44 | if event.EventID() != expectedEventID { 45 | t.Fatalf("event ID '%s' does not match expected '%s'", event.EventID(), expectedEventID) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /invite_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 The Matrix.org Foundation C.I.C. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package gomatrixserverlib 16 | 17 | import ( 18 | "bytes" 19 | "encoding/json" 20 | "testing" 21 | 22 | "github.com/matrix-org/gomatrixserverlib/spec" 23 | ) 24 | 25 | const TestInviteV2ExampleEvent = `{"_room_version":"1","auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"content":{"name":"test3"},"depth":7,"event_id":"$yvN1b43rlmcOs5fY:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"origin":"localhost","origin_server_ts":1510854416361,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"room_id":"!19Mp0U9hjajeIiw1:localhost","sender":"@test:localhost","signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"state_key":"","type":"m.room.name"}` 26 | 27 | func TestEmptyUnsignedFieldIsSetForPDU(t *testing.T) { 28 | output, err := NewEventFromHeaderedJSON([]byte(TestInviteV2ExampleEvent), false) 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | 33 | inviteState := []InviteStrippedState{} 34 | 35 | err = setUnsignedFieldForInvite(output, inviteState) 36 | if err != nil { 37 | t.Fatal(err) 38 | } 39 | 40 | inviteStateJSON, err := json.Marshal(map[string]interface{}{"invite_room_state": struct{}{}}) 41 | if err != nil { 42 | t.Fatal(err) 43 | } 44 | 45 | if !bytes.Equal(output.Unsigned(), inviteStateJSON) { 46 | t.Fatalf("Expected: %v, Got: %v", string(inviteStateJSON[:]), string(output.Unsigned()[:])) 47 | } 48 | } 49 | 50 | func TestEmptyUnsignedFieldIsSetForProtoEvent(t *testing.T) { 51 | senderID := "@test:localhost" 52 | roomID := "!19Mp0U9hjajeIiw1:localhost" 53 | eventType := "m.room.name" 54 | stateKey := "" 55 | prevEvents := 
[]string{"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"} 56 | authEvents := []string{"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY", "X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko", "k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"} 57 | depth := int64(7) 58 | signatures := spec.RawJSON(`{"localhost": {"ed25519:u9kP": "5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}}`) 59 | content := spec.RawJSON(`{"name":"test3"}`) 60 | 61 | output := ProtoEvent{ 62 | SenderID: senderID, 63 | RoomID: roomID, 64 | Type: eventType, 65 | StateKey: &stateKey, 66 | PrevEvents: prevEvents, 67 | AuthEvents: authEvents, 68 | Depth: depth, 69 | Signature: signatures, 70 | Content: content, 71 | } 72 | 73 | inviteState := []InviteStrippedState{} 74 | 75 | err := setUnsignedFieldForProtoInvite(&output, inviteState) 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | 80 | inviteStateJSON, err := json.Marshal(map[string]interface{}{"invite_room_state": struct{}{}}) 81 | if err != nil { 82 | t.Fatal(err) 83 | } 84 | 85 | if !bytes.Equal(output.Unsigned, inviteStateJSON) { 86 | t.Fatalf("Expected: %v, Got: %v", string(inviteStateJSON[:]), string(output.Unsigned[:])) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /spec/userid_test.go: -------------------------------------------------------------------------------- 1 | package spec_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/matrix-org/gomatrixserverlib/spec" 8 | ) 9 | 10 | const defaultDomain = "domain" 11 | const defaultLocalpart = "localpart" 12 | 13 | func TestValidUserIDs(t *testing.T) { 14 | tests := map[string]struct { 15 | localpart string 16 | domain string 17 | allowHistoricIDs bool 18 | }{ 19 | "basic": {localpart: defaultLocalpart, domain: defaultDomain, allowHistoricIDs: false}, 20 | "extensive_local": {localpart: "abcdefghijklmnopqrstuvwxyz0123456789._=-/", domain: defaultDomain, allowHistoricIDs: false}, 21 | 
"extensive_local_historic": {localpart: "!\"#$%&'()*+,-./0123456789;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~", domain: defaultDomain, allowHistoricIDs: true}, 22 | "domain_with_port": {localpart: defaultLocalpart, domain: "domain.org:80", allowHistoricIDs: false}, 23 | "minimum_id": {localpart: "a", domain: "1", allowHistoricIDs: false}, 24 | } 25 | 26 | for name, tc := range tests { 27 | t.Run(name, func(t *testing.T) { 28 | raw := fmt.Sprintf("@%s:%s", tc.localpart, tc.domain) 29 | 30 | userID, err := spec.NewUserID(raw, tc.allowHistoricIDs) 31 | 32 | if err != nil { 33 | t.Fatalf("valid userID should not fail: %s", err.Error()) 34 | } 35 | if userID.Local() != tc.localpart { 36 | t.Fatalf("Localpart - Expected: %s Actual: %s ", tc.localpart, userID.Local()) 37 | } 38 | if userID.Domain() != spec.ServerName(tc.domain) { 39 | t.Fatalf("Domain - Expected: %s Actual: %s ", spec.ServerName(tc.domain), userID.Domain()) 40 | } 41 | if userID.String() != raw { 42 | t.Fatalf("Raw - Expected: %s Actual: %s ", raw, userID.String()) 43 | } 44 | }) 45 | } 46 | } 47 | 48 | func TestInvalidUserIDs(t *testing.T) { 49 | tests := map[string]struct { 50 | rawUserID string 51 | }{ 52 | "empty": {rawUserID: ""}, 53 | "no_leading_@": {rawUserID: "localpart:domain"}, 54 | "no_colon": {rawUserID: "@localpartdomain"}, 55 | "invalid_local_chars": {rawUserID: "@local&part:domain"}, 56 | "invalid_domain_chars": {rawUserID: "@localpart:domain/"}, 57 | "no_local": {rawUserID: "@:domain"}, 58 | "no_domain": {rawUserID: "@localpart:"}, 59 | "too_long": {rawUserID: func() string { 60 | userID := "@a:" 61 | domain := "" 62 | for i := 0; i < 255-len(userID)+1; i++ { 63 | domain = domain + "a" 64 | } 65 | 66 | raw := userID + domain 67 | 68 | if len(raw) <= 255 { 69 | t.Fatalf("ensure the userid is greater than 255 (is %d) characters for this test", len(raw)) 70 | } 71 | return raw 72 | }()}, 73 | } 74 | 75 | for name, tc := range tests { 76 | t.Run(name, func(t 
*testing.T) { 77 | _, err := spec.NewUserID(tc.rawUserID, false) 78 | 79 | if err == nil { 80 | t.Fatalf("userID is not valid, it shouldn't parse") 81 | } 82 | }) 83 | } 84 | } 85 | 86 | func TestSameUserIDsAreEqual(t *testing.T) { 87 | id := "@localpart:domain" 88 | 89 | userID, err := spec.NewUserID(id, false) 90 | userID2, err2 := spec.NewUserID(id, false) 91 | 92 | if err != nil || err2 != nil { 93 | t.Fatalf("userID is valid, it should parse") 94 | } 95 | 96 | if *userID != *userID2 { 97 | t.Fatalf("userIDs should be equal") 98 | } 99 | } 100 | 101 | func TestDifferentUserIDsAreNotEqual(t *testing.T) { 102 | id := "@localpart:domain" 103 | id2 := "@localpart2:domain" 104 | 105 | userID, err := spec.NewUserID(id, false) 106 | userID2, err2 := spec.NewUserID(id2, false) 107 | 108 | if err != nil || err2 != nil { 109 | t.Fatalf("userID is valid, it should parse") 110 | } 111 | 112 | if *userID == *userID2 { 113 | t.Fatalf("userIDs shouldn't be equal") 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /event.go: -------------------------------------------------------------------------------- 1 | /* Copyright 2016-2017 Vector Creations Ltd 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package gomatrixserverlib 17 | 18 | import ( 19 | "fmt" 20 | "strings" 21 | "unicode/utf8" 22 | 23 | "github.com/matrix-org/gomatrixserverlib/spec" 24 | ) 25 | 26 | // Event validation errors 27 | const ( 28 | EventValidationTooLarge int = 1 29 | ) 30 | 31 | // EventValidationError is returned if there is a problem validating an event 32 | type EventValidationError struct { 33 | Message string 34 | Code int 35 | Persistable bool 36 | } 37 | 38 | func (e EventValidationError) Error() string { 39 | return e.Message 40 | } 41 | 42 | type eventFields struct { 43 | RoomID string `json:"room_id"` 44 | SenderID string `json:"sender"` 45 | Type string `json:"type"` 46 | StateKey *string `json:"state_key"` 47 | Content spec.RawJSON `json:"content"` 48 | Redacts string `json:"redacts"` 49 | Depth int64 `json:"depth"` 50 | Unsigned spec.RawJSON `json:"unsigned,omitempty"` 51 | OriginServerTS spec.Timestamp `json:"origin_server_ts"` 52 | //Origin spec.ServerName `json:"origin"` 53 | } 54 | 55 | var emptyEventReferenceList = []eventReference{} 56 | 57 | const ( 58 | // The event ID, room ID, sender, event type and state key fields cannot be 59 | // bigger than this. 60 | // https://github.com/matrix-org/synapse/blob/v0.21.0/synapse/event_auth.py#L173-L182 61 | maxIDLength = 255 62 | // The entire event JSON, including signatures cannot be bigger than this. 
63 | // https://github.com/matrix-org/synapse/blob/v0.21.0/synapse/event_auth.py#L183-184 64 | maxEventLength = 65536 65 | ) 66 | 67 | func checkID(id, kind string, sigil byte) (err error) { 68 | if _, err = domainFromID(id); err != nil { 69 | return 70 | } 71 | if id[0] != sigil { 72 | err = fmt.Errorf( 73 | "gomatrixserverlib: invalid %s ID, wanted first byte to be '%c' got '%c'", 74 | kind, sigil, id[0], 75 | ) 76 | return 77 | } 78 | if l := utf8.RuneCountInString(id); l > maxIDLength { 79 | err = EventValidationError{ 80 | Code: EventValidationTooLarge, 81 | Message: fmt.Sprintf("gomatrixserverlib: %s ID is too long, length %d > maximum %d", kind, l, maxIDLength), 82 | } 83 | return 84 | } 85 | if l := len(id); l > maxIDLength { 86 | err = EventValidationError{ 87 | Code: EventValidationTooLarge, 88 | Message: fmt.Sprintf("gomatrixserverlib: %s ID is too long, length %d bytes > maximum %d bytes", kind, l, maxIDLength), 89 | Persistable: true, 90 | } 91 | return 92 | } 93 | return 94 | } 95 | 96 | // SplitID splits a matrix ID into a local part and a server name. 97 | func SplitID(sigil byte, id string) (local string, domain spec.ServerName, err error) { 98 | // IDs have the format: SIGIL LOCALPART ":" DOMAIN 99 | // Split on the first ":" character since the domain can contain ":" 100 | // characters. 101 | if len(id) == 0 || id[0] != sigil { 102 | return "", "", fmt.Errorf("gomatrixserverlib: invalid ID %q doesn't start with %q", id, sigil) 103 | } 104 | parts := strings.SplitN(id, ":", 2) 105 | if len(parts) != 2 { 106 | // The ID must have a ":" character. 
107 | return "", "", fmt.Errorf("gomatrixserverlib: invalid ID %q missing ':'", id) 108 | } 109 | return parts[0][1:], spec.ServerName(parts[1]), nil 110 | } 111 | -------------------------------------------------------------------------------- /spec/base64.go: -------------------------------------------------------------------------------- 1 | /* Copyright 2016-2017 Vector Creations Ltd 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package spec 17 | 18 | import ( 19 | "database/sql/driver" 20 | "encoding/base64" 21 | "encoding/json" 22 | "fmt" 23 | "strings" 24 | ) 25 | 26 | // A Base64Bytes is a string of bytes (not base64 encoded) that are 27 | // base64 encoded when used in JSON. 28 | // 29 | // The bytes encoded using base64 when marshalled as JSON. 30 | // When the bytes are unmarshalled from JSON they are decoded from base64. 31 | // 32 | // When scanning directly from a database, a string column will be 33 | // decoded from base64 automatically whereas a bytes column will be 34 | // copied as-is. 
35 | type Base64Bytes []byte 36 | 37 | // Encode encodes the bytes as base64 38 | func (b64 Base64Bytes) Encode() string { 39 | return base64.RawStdEncoding.EncodeToString(b64) 40 | } 41 | 42 | // Decode decodes the given input into this Base64Bytes 43 | func (b64 *Base64Bytes) Decode(str string) error { 44 | // We must check whether the string was encoded in a URL-safe way in order 45 | // to use the appropriate encoding. 46 | var err error 47 | if strings.ContainsAny(str, "-_") { 48 | *b64, err = base64.RawURLEncoding.DecodeString(str) 49 | } else { 50 | *b64, err = base64.RawStdEncoding.DecodeString(str) 51 | } 52 | return err 53 | } 54 | 55 | // Implements sql.Scanner 56 | func (b64 *Base64Bytes) Scan(src interface{}) error { 57 | switch v := src.(type) { 58 | case string: 59 | return b64.Decode(v) 60 | case []byte: 61 | *b64 = append(Base64Bytes{}, v...) 62 | return nil 63 | case RawJSON: 64 | return b64.UnmarshalJSON(v) 65 | default: 66 | return fmt.Errorf("unsupported source type") 67 | } 68 | } 69 | 70 | // Implements sql.Valuer 71 | func (b64 Base64Bytes) Value() (driver.Value, error) { 72 | return b64.Encode(), nil 73 | } 74 | 75 | // MarshalJSON encodes the bytes as base64 and then encodes the base64 as a JSON string. 76 | // This takes a value receiver so that maps and slices of Base64Bytes encode correctly. 77 | func (b64 Base64Bytes) MarshalJSON() ([]byte, error) { 78 | // This could be made more efficient by using base64.RawStdEncoding.Encode 79 | // to write the base64 directly to the JSON. We don't need to JSON escape 80 | // any of the characters used in base64. 81 | return json.Marshal(b64.Encode()) 82 | } 83 | 84 | // MarshalYAML implements yaml.Marshaller 85 | // It just encodes the bytes as base64, which is a valid YAML string 86 | func (b64 Base64Bytes) MarshalYAML() (interface{}, error) { 87 | return b64.Encode(), nil 88 | } 89 | 90 | // UnmarshalJSON decodes a JSON string and then decodes the resulting base64. 
// This takes a pointer receiver because it needs to write the result of decoding.
func (b64 *Base64Bytes) UnmarshalJSON(raw []byte) (err error) {
	// We could add a fast path that used base64.RawStdEncoding.Decode
	// directly on the raw JSON if the JSON didn't contain any escapes.
	var str string
	if err = json.Unmarshal(raw, &str); err != nil {
		// raw was not a JSON string (or was malformed JSON).
		return
	}
	// Delegate the base64 step to Decode, which selects the alphabet.
	err = b64.Decode(str)
	return
}

// UnmarshalYAML implements yaml.Unmarshaller.
// It unmarshals the input as a YAML string and then base64-decodes the result.
func (b64 *Base64Bytes) UnmarshalYAML(unmarshal func(interface{}) error) (err error) {
	var str string
	if err = unmarshal(&str); err != nil {
		return
	}
	err = b64.Decode(str)
	return
}

--------------------------------------------------------------------------------
/spec/eventtypes.go:
--------------------------------------------------------------------------------
package spec

const (
	// Join is the string constant "join"
	Join = "join"
	// Ban is the string constant "ban"
	Ban = "ban"
	// Leave is the string constant "leave"
	Leave = "leave"
	// Invite is the string constant "invite"
	Invite = "invite"
	// Knock is the string constant "knock"
	Knock = "knock"
	// Restricted is the string constant "restricted"
	Restricted = "restricted"
	// KnockRestricted is the string constant "knock_restricted" (MSC3787).
	// NOTE(review): MSC3787 has been merged (believed to be part of room
	// version 10). Confirm against the spec and drop the NOTSPEC marker if so.
	KnockRestricted = "knock_restricted"
	// NOTSPEC: Peek is the string constant "peek" (MSC2753, used as the label in the sync block)
	Peek = "peek"
	// Public is the string constant "public"
	Public = "public"
	// WorldReadable is the string constant "world_readable"
	WorldReadable = "world_readable"
	// PresetPrivateChat is the room creation preset used to create private rooms.
	PresetPrivateChat = "private_chat"
	// PresetTrustedPrivateChat is the room creation preset used to create trusted private rooms.
	PresetTrustedPrivateChat = "trusted_private_chat"
	// PresetPublicChat is the room creation preset used to create public rooms.
	PresetPublicChat = "public_chat"
	// MRoomCreate https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-create
	MRoomCreate = "m.room.create"
	// MRoomJoinRules https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-join-rules
	MRoomJoinRules = "m.room.join_rules"
	// MRoomPowerLevels https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-power-levels
	MRoomPowerLevels = "m.room.power_levels"
	// MRoomName https://matrix.org/docs/spec/client_server/r0.6.0#m-room-name
	MRoomName = "m.room.name"
	// MRoomTopic https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-topic
	MRoomTopic = "m.room.topic"
	// MRoomAvatar https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-avatar
	MRoomAvatar = "m.room.avatar"
	// MRoomMember https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-member
	MRoomMember = "m.room.member"
	// MRoomThirdPartyInvite https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-third-party-invite
	MRoomThirdPartyInvite = "m.room.third_party_invite"
	// MRoomAliases https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-aliases
	MRoomAliases = "m.room.aliases"
	// MRoomCanonicalAlias https://matrix.org/docs/spec/client_server/r0.6.0#m-room-canonical-alias
	MRoomCanonicalAlias = "m.room.canonical_alias"
	// MRoomHistoryVisibility https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-history-visibility
	MRoomHistoryVisibility = "m.room.history_visibility"
	// MRoomGuestAccess https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-guest-access
	MRoomGuestAccess = "m.room.guest_access"
	// MRoomEncryption https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-encryption
	MRoomEncryption = "m.room.encryption"
	// MRoomRedaction https://matrix.org/docs/spec/client_server/r0.2.0.html#id21
	MRoomRedaction = "m.room.redaction"
	// MTyping https://matrix.org/docs/spec/client_server/r0.3.0.html#m-typing
	MTyping = "m.typing"
	// MDirectToDevice https://matrix.org/docs/spec/server_server/r0.1.3#send-to-device-messaging
	MDirectToDevice = "m.direct_to_device"
	// MDeviceListUpdate https://matrix.org/docs/spec/server_server/latest#m-device-list-update-schema
	MDeviceListUpdate = "m.device_list_update"
	// MReceipt https://matrix.org/docs/spec/server_server/r0.1.4#receipts
	MReceipt = "m.receipt"
	// MPresence https://matrix.org/docs/spec/server_server/latest#m-presence-schema
	MPresence = "m.presence"
	// MRoomMembership https://github.com/matrix-org/matrix-doc/blob/clokep/restricted-rooms/proposals/3083-restricted-rooms.md
	MRoomMembership = "m.room_membership"
	// MSpaceChild https://spec.matrix.org/v1.7/client-server-api/#mspacechild-relationship
	MSpaceChild = "m.space.child"
	// MSpaceParent https://spec.matrix.org/v1.7/client-server-api/#mspaceparent-relationships
	MSpaceParent = "m.space.parent"
)

--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
name: Tests

on:
  push:
    branches:
      - main
  pull_request:

# Cancel an in-progress run when a newer run starts for the same workflow and ref.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  # Run golangci-lint
  lint:
    timeout-minutes: 5
    name: Linting
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Install Go
        uses: actions/setup-go@v3
        with:
          go-version: 1.23
      - name: golangci-lint
        uses: golangci/golangci-lint-action@v3

  # run go test with different go versions
  test:
    timeout-minutes: 5
    name: Unit tests (Go ${{ matrix.go }})
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        go: ["stable", "1.23"]
    steps:
      - uses: actions/checkout@v3
      - name: Setup go
        uses: actions/setup-go@v3
        with:
          go-version: ${{ matrix.go }}
      - uses: actions/cache@v3
        with:
          path: |
            ~/.cache/go-build
            ~/go/pkg/mod
          key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
          restore-keys: |
            ${{ runner.os }}-go${{ matrix.go }}-test-
      # Race-enabled tests; coverage is collected across all packages for Codecov.
      - run: go test -race -coverpkg=./... -coverprofile=cover.out $(go list ./...)
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v4
        with:
          fail_ci_if_error: true
          token: ${{ secrets.CODECOV_TOKEN }}

  # run go test on Dendrite with different go versions
  test-dendrite:
    timeout-minutes: 10
    name: Unit tests Dendrite (Go ${{ matrix.go }})
    runs-on: ubuntu-latest
    # Service containers started alongside this job; the final test step
    # points Dendrite at the postgres service through its POSTGRES_* env.
    services:
      # Label used to access the service container
      postgres:
        # Docker Hub image
        image: postgres:13-alpine
        # Credentials and database name for the postgres service
        env:
          POSTGRES_USER: postgres
          POSTGRES_PASSWORD: postgres
          POSTGRES_DB: dendrite
        ports:
          # Maps tcp port 5432 on service container to the host
          - 5432:5432
        # Set health checks to wait until postgres has started
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
    strategy:
      fail-fast: false
      matrix:
        go: ["stable", "1.23"]
    steps:
      - uses: actions/checkout@v3
        with:
          repository: "element-hq/dendrite"
      - name: Install libolm
        run: sudo apt-get install libolm-dev libolm3
      - name: Setup go
        uses: actions/setup-go@v3
        with:
          go-version: ${{ matrix.go }}
      - uses: actions/cache@v3
        with:
          path: |
            ~/.cache/go-build
            ~/go/pkg/mod
          key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
          restore-keys: |
            ${{ runner.os }}-go${{ matrix.go }}-test-
      - if: github.event_name == 'pull_request'
        env:
          REPOSITORY: ${{ github.event.pull_request.head.repo.full_name }}
          PULL_SHA: ${{ github.event.pull_request.head.sha }}
        # Replace matrix-org/gomatrixserverlib with the repository sending the pull request
        run: go mod edit -replace "github.com/matrix-org/gomatrixserverlib=github.com/${REPOSITORY}@${PULL_SHA}" && go mod tidy
      # On pushes to main, build Dendrite against the commit being tested.
      - if: github.ref_name == 'main'
        run: go get github.com/matrix-org/gomatrixserverlib@${{ github.sha }} && go mod tidy
      - run: go test ./...
        env:
          POSTGRES_HOST: localhost
          POSTGRES_USER: postgres
          POSTGRES_PASSWORD: postgres
          POSTGRES_DB: dendrite

--------------------------------------------------------------------------------
/fclient/crosssigning.go:
--------------------------------------------------------------------------------
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 | 15 | package fclient 16 | 17 | import ( 18 | "bytes" 19 | "encoding/json" 20 | "slices" 21 | 22 | "github.com/matrix-org/gomatrixserverlib" 23 | "github.com/matrix-org/gomatrixserverlib/spec" 24 | "github.com/tidwall/gjson" 25 | ) 26 | 27 | type CrossSigningKeyPurpose string 28 | 29 | const ( 30 | CrossSigningKeyPurposeMaster CrossSigningKeyPurpose = "master" 31 | CrossSigningKeyPurposeSelfSigning CrossSigningKeyPurpose = "self_signing" 32 | CrossSigningKeyPurposeUserSigning CrossSigningKeyPurpose = "user_signing" 33 | ) 34 | 35 | type CrossSigningKeys struct { 36 | MasterKey CrossSigningKey `json:"master_key"` 37 | SelfSigningKey CrossSigningKey `json:"self_signing_key"` 38 | UserSigningKey CrossSigningKey `json:"user_signing_key"` 39 | } 40 | 41 | // https://spec.matrix.org/unstable/client-server-api/#post_matrixclientr0keysdevice_signingupload 42 | type CrossSigningKey struct { 43 | Signatures map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes `json:"signatures,omitempty"` 44 | Keys map[gomatrixserverlib.KeyID]spec.Base64Bytes `json:"keys"` 45 | Usage []CrossSigningKeyPurpose `json:"usage"` 46 | UserID string `json:"user_id"` 47 | } 48 | 49 | func (s *CrossSigningKey) isCrossSigningBody() {} // implements CrossSigningBody 50 | 51 | func (s *CrossSigningKey) Equal(other *CrossSigningKey) bool { 52 | if s == nil || other == nil { 53 | return false 54 | } 55 | if s.UserID != other.UserID { 56 | return false 57 | } 58 | if len(s.Usage) != len(other.Usage) { 59 | return false 60 | } 61 | 62 | // Make sure the slices are sorted before we compare them. 
63 | if !slices.IsSorted(s.Usage) { 64 | slices.Sort(s.Usage) 65 | } 66 | if !slices.IsSorted(other.Usage) { 67 | slices.Sort(other.Usage) 68 | } 69 | for i := range s.Usage { 70 | if s.Usage[i] != other.Usage[i] { 71 | return false 72 | } 73 | } 74 | if len(s.Keys) != len(other.Keys) { 75 | return false 76 | } 77 | for k, v := range s.Keys { 78 | if !bytes.Equal(other.Keys[k], v) { 79 | return false 80 | } 81 | } 82 | if len(s.Signatures) != len(other.Signatures) { 83 | return false 84 | } 85 | for k, v := range s.Signatures { 86 | otherV, ok := other.Signatures[k] 87 | if !ok { 88 | return false 89 | } 90 | if len(v) != len(otherV) { 91 | return false 92 | } 93 | for k2, v2 := range v { 94 | if !bytes.Equal(otherV[k2], v2) { 95 | return false 96 | } 97 | } 98 | } 99 | return true 100 | } 101 | 102 | type CrossSigningBody interface { 103 | isCrossSigningBody() 104 | } 105 | 106 | type CrossSigningForKeyOrDevice struct { 107 | CrossSigningBody 108 | } 109 | 110 | // Implements json.Marshaler 111 | func (c CrossSigningForKeyOrDevice) MarshalJSON() ([]byte, error) { 112 | // Marshal the contents at the top level, rather than having it embedded 113 | // in a "CrossSigningBody" JSON key. 
114 | return json.Marshal(c.CrossSigningBody) 115 | } 116 | 117 | // Implements json.Unmarshaler 118 | func (c *CrossSigningForKeyOrDevice) UnmarshalJSON(b []byte) error { 119 | if gjson.GetBytes(b, "device_id").Exists() { 120 | body := &DeviceKeys{} 121 | if err := json.Unmarshal(b, body); err != nil { 122 | return err 123 | } 124 | c.CrossSigningBody = body 125 | return nil 126 | } 127 | body := &CrossSigningKey{} 128 | if err := json.Unmarshal(b, body); err != nil { 129 | return err 130 | } 131 | c.CrossSigningBody = body 132 | return nil 133 | } 134 | -------------------------------------------------------------------------------- /fclient/well_known.go: -------------------------------------------------------------------------------- 1 | package fclient 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "strconv" 11 | "strings" 12 | "time" 13 | 14 | "github.com/matrix-org/gomatrixserverlib/spec" 15 | ) 16 | 17 | var ( 18 | errNoWellKnown = errors.New("No .well-known found") 19 | ) 20 | 21 | const WellKnownMaxSize = 50 * 1024 // 50KB 22 | 23 | // WellKnownResult is the result of looking up a matrix server's well-known file. 24 | // Located at https:///.well-known/matrix/server 25 | type WellKnownResult struct { 26 | NewAddress spec.ServerName `json:"m.server"` 27 | CacheExpiresAt int64 28 | } 29 | 30 | // LookupWellKnown looks up a well-known record for a matrix server. If one if 31 | // found, it returns the server to redirect to. 
32 | func LookupWellKnown(ctx context.Context, serverNameType spec.ServerName) (*WellKnownResult, error) { 33 | serverName := string(serverNameType) 34 | 35 | // Handle ending "/" 36 | serverName = strings.TrimRight(serverName, "/") 37 | 38 | wellKnownPath := "/.well-known/matrix/server" 39 | 40 | // Request server's well-known record 41 | req, err := http.NewRequestWithContext(ctx, "GET", "https://"+serverName+wellKnownPath, nil) 42 | if err != nil { 43 | return nil, err 44 | } 45 | // Given well-known should be quite small and fast to fetch, timeout the request after 30s. 46 | client := http.Client{Timeout: time.Second * 30} 47 | resp, err := client.Do(req) 48 | if err != nil { 49 | return nil, err 50 | } 51 | defer func() { 52 | _ = resp.Body.Close() 53 | }() 54 | if resp.StatusCode != 200 { 55 | return nil, errNoWellKnown 56 | } 57 | 58 | // If the remote server reports a Content-Length to us then make sure 59 | // that the well-known response size doesn't exceed WellKnownMaxSize. 60 | contentLengthHeader := resp.Header.Get("Content-Length") 61 | if l, err := strconv.Atoi(contentLengthHeader); err == nil && l > WellKnownMaxSize { 62 | return nil, fmt.Errorf("well-known content length %d exceeds %d bytes", l, WellKnownMaxSize) 63 | } 64 | 65 | // Figure out when the cache expiry time of this well-known record is 66 | cacheControlHeader := resp.Header.Get("Cache-Control") 67 | expiresHeader := resp.Header.Get("Expires") 68 | 69 | expiryTimestamp := int64(0) 70 | 71 | if expiresHeader != "" { 72 | // parse the HTTP-date (RFC7231 section 7.1.1.1) 73 | // Mon Jan 2 15:04:05 -0700 MST 2006 74 | referenceTimeFormat := "Mon, 02 Jan 2006 15:04:05 MST" 75 | expiresTime, err := time.Parse(referenceTimeFormat, expiresHeader) 76 | if err == nil { 77 | expiryTimestamp = expiresTime.Unix() 78 | } 79 | } 80 | 81 | // According to RFC7234 section 5.3, Cache-Control with max-age directive 82 | // MUST be preferred to Expires header. 
83 | if cacheControlHeader != "" { 84 | kvPairs := strings.Split(cacheControlHeader, ",") 85 | for _, keyValuePair := range kvPairs { 86 | keyValuePair = strings.Trim(keyValuePair, " ") 87 | pieces := strings.SplitN(keyValuePair, "=", 2) 88 | if len(pieces) == 2 && strings.EqualFold(pieces[0], "max-age") { 89 | // max-age is the (maximum) number of seconds this record can 90 | // be assumed to live 91 | stringValue := pieces[1] 92 | age, err := strconv.ParseInt(stringValue, 10, 64) 93 | 94 | if err == nil { 95 | expiryTimestamp = age + time.Now().Unix() 96 | } 97 | } 98 | } 99 | } 100 | 101 | // By this point we hope that we've caught any huge well-known records 102 | // by checking Content-Length, but it's possible that header will be 103 | // missing. Better to be safe than sorry by reading no more than the 104 | // WellKnownMaxSize in any case. 105 | body, err := io.ReadAll(&io.LimitedReader{R: resp.Body, N: WellKnownMaxSize}) 106 | if err != nil { 107 | return nil, err 108 | } 109 | 110 | // Convert result to JSON 111 | wellKnownResponse := &WellKnownResult{ 112 | CacheExpiresAt: expiryTimestamp, 113 | } 114 | err = json.Unmarshal(body, wellKnownResponse) 115 | if err != nil { 116 | return nil, err 117 | } 118 | 119 | if wellKnownResponse.NewAddress == "" { 120 | return nil, errors.New("No m.server key found in well-known response") 121 | } 122 | 123 | // Return result 124 | return wellKnownResponse, nil 125 | } 126 | -------------------------------------------------------------------------------- /handleleave.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 The Matrix.org Foundation C.I.C. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gomatrixserverlib 16 | 17 | import ( 18 | "fmt" 19 | 20 | "github.com/matrix-org/gomatrixserverlib/spec" 21 | ) 22 | 23 | type HandleMakeLeaveResponse struct { 24 | LeaveTemplateEvent ProtoEvent 25 | RoomVersion RoomVersion 26 | } 27 | 28 | type HandleMakeLeaveInput struct { 29 | UserID spec.UserID // The user wanting to leave the room 30 | SenderID spec.SenderID // The senderID of the user wanting to leave the room 31 | RoomID spec.RoomID // The room the user wants to leave 32 | RoomVersion RoomVersion // The room version for the room being left 33 | RequestOrigin spec.ServerName // The server that sent the /make_leave federation request 34 | LocalServerName spec.ServerName // The name of this local server 35 | LocalServerInRoom bool // Whether this local server has a user currently joined to the room 36 | UserIDQuerier spec.UserIDForSender // Provides userIDs given a senderID 37 | 38 | // Returns a fully built version of the proto event and a list of state events required to auth this event 39 | BuildEventTemplate func(*ProtoEvent) (PDU, []PDU, error) 40 | } 41 | 42 | func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, error) { 43 | 44 | if input.UserID.Domain() != input.RequestOrigin { 45 | return nil, spec.Forbidden(fmt.Sprintf("The leave must be sent by the server of the user. 
Origin %s != %s", 46 | input.RequestOrigin, input.UserID.Domain())) 47 | } 48 | 49 | // Check if we think we are still joined to the room 50 | if !input.LocalServerInRoom { 51 | return nil, spec.NotFound(fmt.Sprintf("Local server not currently joined to room: %s", input.RoomID.String())) 52 | } 53 | 54 | // Try building an event for the server 55 | rawSenderID := string(input.SenderID) 56 | proto := ProtoEvent{ 57 | SenderID: string(input.SenderID), 58 | RoomID: input.RoomID.String(), 59 | Type: spec.MRoomMember, 60 | StateKey: &rawSenderID, 61 | } 62 | content := MemberContent{ 63 | Membership: spec.Leave, 64 | } 65 | 66 | if err := proto.SetContent(content); err != nil { 67 | return nil, spec.InternalServerError{Err: "builder.SetContent failed"} 68 | } 69 | 70 | event, stateEvents, templateErr := input.BuildEventTemplate(&proto) 71 | if templateErr != nil { 72 | return nil, templateErr 73 | } 74 | if event == nil { 75 | return nil, spec.InternalServerError{Err: "template builder returned nil event"} 76 | } 77 | if stateEvents == nil { 78 | return nil, spec.InternalServerError{Err: "template builder returned nil event state"} 79 | } 80 | if event.Type() != spec.MRoomMember { 81 | return nil, spec.InternalServerError{Err: fmt.Sprintf("expected leave event from template builder. got: %s", event.Type())} 82 | } 83 | 84 | provider, err := NewAuthEvents(stateEvents) 85 | if err != nil { 86 | return nil, spec.Forbidden(err.Error()) 87 | } 88 | if err = Allowed(event, provider, input.UserIDQuerier); err != nil { 89 | return nil, spec.Forbidden(err.Error()) 90 | } 91 | 92 | // This ensures we send EventReferences for room version v1 and v2. We need to do this, since we're 93 | // returning the proto event, which isn't modified when running `Build`. 
94 | switch event.Version() { 95 | case RoomVersionV1, RoomVersionV2: 96 | proto.PrevEvents = toEventReference(event.PrevEventIDs()) 97 | proto.AuthEvents = toEventReference(event.AuthEventIDs()) 98 | } 99 | 100 | makeLeaveResponse := HandleMakeLeaveResponse{ 101 | LeaveTemplateEvent: proto, 102 | RoomVersion: input.RoomVersion, 103 | } 104 | return &makeLeaveResponse, nil 105 | } 106 | -------------------------------------------------------------------------------- /tokens/tokens.go: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | 13 | package tokens 14 | 15 | import ( 16 | "encoding/base64" 17 | "errors" 18 | "fmt" 19 | "strconv" 20 | "time" 21 | 22 | macaroon "gopkg.in/macaroon.v2" 23 | ) 24 | 25 | const ( 26 | macaroonVersion = macaroon.V2 27 | defaultDuration = 2 * 60 28 | // UserPrefix is a common prefix for every user_id caveat 29 | UserPrefix = "user_id = " 30 | // TimePrefix is a common prefix for every expiry caveat 31 | TimePrefix = "time < " 32 | // Gen is a common caveat for every token 33 | Gen = "gen = 1" 34 | ) 35 | 36 | // TokenOptions represent parameters of Token 37 | type TokenOptions struct { 38 | ServerPrivateKey []byte `yaml:"private_key"` 39 | ServerName string `yaml:"server_name"` 40 | UserID string `json:"user_id"` 41 | // The valid period of the token in seconds since its generation. 
42 | // Only used in GenerateLoginToken; 0 is treated as defaultDuration. 43 | Duration int 44 | } 45 | 46 | // GenerateLoginToken generates a short term login token to be used as 47 | // token authentication ("m.login.token") 48 | func GenerateLoginToken(op TokenOptions) (string, error) { 49 | if !isValidTokenOptions(op) { 50 | return "", errors.New("The given TokenOptions is invalid") 51 | } 52 | 53 | mac, err := generateBaseMacaroon(op.ServerPrivateKey, op.ServerName, op.UserID) 54 | if err != nil { 55 | return "", err 56 | } 57 | 58 | if op.Duration == 0 { 59 | op.Duration = defaultDuration 60 | } 61 | now := time.Now().Second() 62 | expiryCaveat := TimePrefix + strconv.Itoa(now+op.Duration) 63 | err = mac.AddFirstPartyCaveat([]byte(expiryCaveat)) 64 | if err != nil { 65 | return "", macaroonError(err) 66 | } 67 | 68 | urlSafeEncode, err := serializeMacaroon(*mac) 69 | if err != nil { 70 | return "", macaroonError(err) 71 | } 72 | return urlSafeEncode, nil 73 | } 74 | 75 | // isValidTokenOptions checks for required fields in a TokenOptions 76 | func isValidTokenOptions(op TokenOptions) bool { 77 | if op.ServerPrivateKey == nil || op.ServerName == "" || op.UserID == "" { 78 | return false 79 | } 80 | return true 81 | } 82 | 83 | // generateBaseMacaroon generates a base macaroon common for accessToken & loginToken. 84 | // Returns a macaroon tied with userID, 85 | // returns an error if something goes wrong. 
86 | func generateBaseMacaroon( 87 | secret []byte, ServerName string, userID string, 88 | ) (*macaroon.Macaroon, error) { 89 | mac, err := macaroon.New(secret, []byte(userID), ServerName, macaroonVersion) 90 | if err != nil { 91 | return nil, macaroonError(err) 92 | } 93 | 94 | err = mac.AddFirstPartyCaveat([]byte(Gen)) 95 | if err != nil { 96 | return nil, macaroonError(err) 97 | } 98 | 99 | err = mac.AddFirstPartyCaveat([]byte(UserPrefix + userID)) 100 | if err != nil { 101 | return nil, macaroonError(err) 102 | } 103 | 104 | return mac, nil 105 | } 106 | 107 | func macaroonError(err error) error { 108 | return fmt.Errorf("Macaroon creation failed: %s", err.Error()) 109 | } 110 | 111 | // serializeMacaroon takes a macaroon to be serialized. 112 | // returns its base64 encoded string, URL safe, which can be sent via web, email, etc. 113 | func serializeMacaroon(m macaroon.Macaroon) (string, error) { 114 | bin, err := m.MarshalBinary() 115 | if err != nil { 116 | return "", err 117 | } 118 | 119 | urlSafeEncode := base64.RawURLEncoding.EncodeToString(bin) 120 | return urlSafeEncode, nil 121 | } 122 | 123 | // deSerializeMacaroon takes a base64 encoded string of a macaroon to be de-serialized. 124 | // Returns a macaroon. On failure returns error with description. 
125 | func deSerializeMacaroon(urlSafeEncode string) (macaroon.Macaroon, error) { 126 | var mac macaroon.Macaroon 127 | bin, err := base64.RawURLEncoding.DecodeString(urlSafeEncode) 128 | if err != nil { 129 | return mac, err 130 | } 131 | 132 | err = mac.UnmarshalBinary(bin) 133 | return mac, err 134 | } 135 | -------------------------------------------------------------------------------- /fclient/dnscache.go: -------------------------------------------------------------------------------- 1 | package fclient 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | type DNSCache struct { 12 | resolver netResolver 13 | mutex sync.Mutex 14 | size int 15 | duration time.Duration 16 | entries map[string]*dnsCacheEntry 17 | dialer net.Dialer 18 | } 19 | 20 | func NewDNSCache(size int, duration time.Duration, allowNetworks, denyNetworks []string) *DNSCache { 21 | return &DNSCache{ 22 | resolver: net.DefaultResolver, 23 | size: size, 24 | duration: duration, 25 | entries: make(map[string]*dnsCacheEntry), 26 | dialer: net.Dialer{ 27 | ControlContext: allowDenyNetworksControl(allowNetworks, denyNetworks), 28 | }, 29 | } 30 | } 31 | 32 | type dnsCacheEntry struct { 33 | addrs []net.IPAddr 34 | expires time.Time 35 | } 36 | 37 | type netResolver interface { 38 | LookupIPAddr(context.Context, string) ([]net.IPAddr, error) 39 | } 40 | 41 | func (c *DNSCache) lookup(ctx context.Context, name string) (*dnsCacheEntry, bool) { 42 | // Check to see if there's something in the cache for this name. 43 | c.mutex.Lock() 44 | if entry, ok := c.entries[name]; ok { 45 | // Check the expiry of the cache entry. If it's still within 46 | // the expiry period then return the entry as-is. 47 | if time.Now().Before(entry.expires) { 48 | c.mutex.Unlock() 49 | return entry, true 50 | } 51 | 52 | // If it's outside of the validity then remove the entry from 53 | // the cache. 
54 | delete(c.entries, name) 55 | } 56 | c.mutex.Unlock() 57 | 58 | // At this point there's either nothing in the cache, or there 59 | // was something in the cache but it's past the validity, so we 60 | // have nuked it. Ask the operating system to perform a lookup 61 | // for us. 62 | 63 | addrs, err := c.resolver.LookupIPAddr(ctx, name) 64 | if err != nil { 65 | return nil, false 66 | } 67 | 68 | c.mutex.Lock() 69 | defer c.mutex.Unlock() 70 | 71 | // If we've hit, or exceed somehow, the maximum size of the cache 72 | // then we will need to evict the oldest entries to make room. 73 | for len(c.entries) >= c.size { 74 | name, ts := "", time.Now().Add(c.duration) 75 | for n, e := range c.entries { 76 | if e.expires.Before(ts) { 77 | ts, name = e.expires, n 78 | } 79 | } 80 | delete(c.entries, name) 81 | } 82 | 83 | // Create a new entry, give it the validity specified when the 84 | // cache was created and then store it. 85 | entry := &dnsCacheEntry{ 86 | addrs: addrs, 87 | expires: time.Now().Add(c.duration), 88 | } 89 | c.entries[name] = entry 90 | 91 | // All good now - return the cache entry. 92 | return entry, false 93 | } 94 | 95 | func (c *DNSCache) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 96 | // Split up the host and port from the give address. 97 | host, port, err := net.SplitHostPort(address) 98 | if err != nil { 99 | return nil, fmt.Errorf("net.SplitHostPort: %w", err) 100 | } 101 | 102 | // On the first attempt, retried will be false. If we try the 103 | // cached entries and none of them connect, we'll retry but with 104 | // retried set to true. This stops us from recursing more than 105 | // once. 106 | retried := false 107 | 108 | retryLookup: 109 | // Consult the cache for the hostname. This will cause the OS to 110 | // ask DNS if needed, updating the cache in the process. 
111 | entry, cached := c.lookup(ctx, host) 112 | if entry == nil { 113 | return nil, fmt.Errorf("lookup failed for %q", host) 114 | } 115 | 116 | // Try each address in the cached entry. If we successfully connect 117 | // to one of those addresses then return the conn and stop there. 118 | for _, addr := range entry.addrs { 119 | conn, err := c.dialer.DialContext(ctx, "tcp", addr.String()+":"+port) 120 | if err != nil { 121 | continue 122 | } 123 | return conn, nil 124 | } 125 | 126 | // If we reached this point then we failed to reach any of the 127 | // addresses in the entry. If the entry came from the cache then 128 | // we'll assume that it's no good anymore - delete the cache entry 129 | // and then retry, which will ask the OS to consult DNS again. 130 | if cached && !retried { 131 | retried = true 132 | c.mutex.Lock() 133 | delete(c.entries, host) 134 | c.mutex.Unlock() 135 | goto retryLookup 136 | } 137 | 138 | // All attempts to find a working connection failed from either 139 | // cached entries or from DNS itself. 
140 | return nil, fmt.Errorf("connection failed to %q via %d addresses", host, len(entry.addrs)) 141 | } 142 | -------------------------------------------------------------------------------- /fclient/crosssigning_test.go: -------------------------------------------------------------------------------- 1 | package fclient 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/matrix-org/gomatrixserverlib" 7 | "github.com/matrix-org/gomatrixserverlib/spec" 8 | ) 9 | 10 | var tests = []struct { 11 | name string 12 | s *CrossSigningKey 13 | other *CrossSigningKey 14 | expect bool 15 | }{ 16 | { 17 | name: "NilReceiver_ReturnsFalse", 18 | s: nil, 19 | other: &CrossSigningKey{}, 20 | expect: false, 21 | }, 22 | { 23 | name: "NilOther_ReturnsFalse", 24 | s: &CrossSigningKey{}, 25 | other: nil, 26 | expect: false, 27 | }, 28 | { 29 | name: "DifferentUserID_ReturnsFalse", 30 | s: &CrossSigningKey{UserID: "user1"}, 31 | other: &CrossSigningKey{UserID: "user2"}, 32 | expect: false, 33 | }, 34 | { 35 | name: "DifferentUsageLength_ReturnsFalse", 36 | s: &CrossSigningKey{Usage: []CrossSigningKeyPurpose{CrossSigningKeyPurposeMaster}}, 37 | other: &CrossSigningKey{Usage: []CrossSigningKeyPurpose{CrossSigningKeyPurposeMaster, CrossSigningKeyPurposeSelfSigning}}, 38 | expect: false, 39 | }, 40 | { 41 | name: "UnsortedUsages_ReturnsTrue", 42 | s: &CrossSigningKey{Usage: []CrossSigningKeyPurpose{CrossSigningKeyPurposeSelfSigning, CrossSigningKeyPurposeMaster}}, 43 | other: &CrossSigningKey{Usage: []CrossSigningKeyPurpose{CrossSigningKeyPurposeMaster, CrossSigningKeyPurposeSelfSigning}}, 44 | expect: true, 45 | }, 46 | { 47 | name: "UnsortedUsages_ReturnsTrue", 48 | s: &CrossSigningKey{Usage: []CrossSigningKeyPurpose{CrossSigningKeyPurposeSelfSigning, CrossSigningKeyPurposeMaster}}, 49 | other: &CrossSigningKey{Usage: []CrossSigningKeyPurpose{CrossSigningKeyPurposeSelfSigning, CrossSigningKeyPurposeMaster}}, 50 | expect: true, 51 | }, 52 | { 53 | name: "DifferentUsageValues_ReturnsFalse", 
54 | s: &CrossSigningKey{Usage: []CrossSigningKeyPurpose{CrossSigningKeyPurposeMaster}}, 55 | other: &CrossSigningKey{Usage: []CrossSigningKeyPurpose{CrossSigningKeyPurposeSelfSigning}}, 56 | expect: false, 57 | }, 58 | { 59 | name: "DifferentKeysLength_ReturnsFalse", 60 | s: &CrossSigningKey{Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"key1": {}}}, 61 | other: &CrossSigningKey{Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"key1": {}, "key2": {}}}, 62 | expect: false, 63 | }, 64 | { 65 | name: "DifferentKeysValues_ReturnsFalse", 66 | s: &CrossSigningKey{Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"key1": {}}}, 67 | other: &CrossSigningKey{Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"key1": {1}}}, 68 | expect: false, 69 | }, 70 | { 71 | name: "DifferentSignaturesLength_ReturnsFalse", 72 | s: &CrossSigningKey{Signatures: map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{"sig1": {"key1": {}}}}, 73 | other: &CrossSigningKey{Signatures: map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{"sig1": {"key1": {}}, "sig2": {"key2": {}}}}, 74 | expect: false, 75 | }, 76 | { 77 | name: "DifferentSignaturesValues_ReturnsFalse", 78 | s: &CrossSigningKey{Signatures: map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{"sig1": {"key1": {}}}}, 79 | other: &CrossSigningKey{Signatures: map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{"sig1": {"key1": {1}}}}, 80 | expect: false, 81 | }, 82 | { 83 | name: "IdenticalKeys_ReturnsTrue", 84 | s: &CrossSigningKey{ 85 | UserID: "user1", 86 | Usage: []CrossSigningKeyPurpose{CrossSigningKeyPurposeMaster}, 87 | Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"key1": {}}, 88 | Signatures: map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{"sig1": {"key1": {}}}, 89 | }, 90 | other: &CrossSigningKey{ 91 | UserID: "user1", 92 | Usage: []CrossSigningKeyPurpose{CrossSigningKeyPurposeMaster}, 93 | Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"key1": {}}, 94 | Signatures: 
map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{"sig1": {"key1": {}}}, 95 | }, 96 | expect: true, 97 | }, 98 | } 99 | 100 | func TestCrossSigningKeyEqual(t *testing.T) { 101 | for _, tt := range tests { 102 | t.Run(tt.name, func(t *testing.T) { 103 | if got := tt.s.Equal(tt.other); got != tt.expect { 104 | t.Errorf("Equal() = %v, want %v", got, tt.expect) 105 | } 106 | }) 107 | } 108 | } 109 | 110 | func BenchmarkEqual(b *testing.B) { 111 | 112 | for i := 0; i < b.N; i++ { 113 | for _, tt := range tests { 114 | if !tt.s.Equal(tt.other) && tt.expect { 115 | b.Fatal(tt.name, tt.s) 116 | } 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /stateresolutionv2heaps.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Matrix.org Foundation C.I.C. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gomatrixserverlib 16 | 17 | import ( 18 | "strings" 19 | 20 | "github.com/matrix-org/gomatrixserverlib/spec" 21 | ) 22 | 23 | // A stateResV2ConflictedPowerLevel is used to sort the events by effective 24 | // power level, origin server TS and the lexicographical comparison of event 25 | // IDs. It is a bit of an optimisation to use this - by working out the 26 | // effective power level etc ahead of time, we use less CPU cycles during the 27 | // sort. 
28 | type stateResV2ConflictedPowerLevel struct { 29 | powerLevel int64 30 | originServerTS spec.Timestamp 31 | eventID string 32 | event PDU 33 | } 34 | 35 | // A stateResV2ConflictedPowerLevelHeap is used to sort the events using 36 | // sort.Sort or by using the heap functions for further optimisation. Sorting 37 | // ensures that the results are deterministic. 38 | type stateResV2ConflictedPowerLevelHeap []*stateResV2ConflictedPowerLevel 39 | 40 | // Less implements sort.Interface 41 | func sortStateResV2ConflictedPowerLevelHeap(a, b *stateResV2ConflictedPowerLevel) int { 42 | // Try to tiebreak on the effective power level 43 | if a.powerLevel > b.powerLevel { 44 | return -1 45 | } 46 | if a.powerLevel < b.powerLevel { 47 | return 1 48 | } 49 | // If we've reached here then s[i].powerLevel == s[j].powerLevel 50 | // so instead try to tiebreak on origin server TS 51 | if a.originServerTS < b.originServerTS { 52 | return -1 53 | } 54 | if a.originServerTS > b.originServerTS { 55 | return 1 56 | } 57 | // If we've reached here then s[i].originServerTS == s[j].originServerTS 58 | // so instead try to tiebreak on a lexicographical comparison of the event ID 59 | return strings.Compare(a.eventID[:], b.eventID[:]) 60 | } 61 | 62 | // Push implements heap.Interface 63 | func (s *stateResV2ConflictedPowerLevelHeap) Push(x *stateResV2ConflictedPowerLevel) { 64 | *s = append(*s, x) 65 | } 66 | 67 | // Pop implements heap.Interface 68 | func (s *stateResV2ConflictedPowerLevelHeap) Pop() *stateResV2ConflictedPowerLevel { 69 | old := *s 70 | n := len(old) 71 | x := old[n-1] 72 | *s = old[:n-1] 73 | return x 74 | } 75 | 76 | // A stateResV2ConflictedOther is used to sort the events by power level 77 | // mainline positions, origin server TS and the lexicographical comparison of 78 | // event IDs. It is a bit of an optimisation to use this - by working out the 79 | // effective power level etc ahead of time, we use less CPU cycles during the 80 | // sort. 
81 | type stateResV2ConflictedOther struct { 82 | mainlinePosition int 83 | mainlineSteps int 84 | originServerTS spec.Timestamp 85 | eventID string 86 | event PDU 87 | } 88 | 89 | // A stateResV2ConflictedOtherHeap is used to sort the events using 90 | // sort.Sort or by using the heap functions for further optimisation. Sorting 91 | // ensures that the results are deterministic. 92 | type stateResV2ConflictedOtherHeap []*stateResV2ConflictedOther 93 | 94 | func sortStateResV2ConflictedOtherHeap(a, b *stateResV2ConflictedOther) int { 95 | // Try to tiebreak on the mainline position 96 | if a.mainlinePosition < b.mainlinePosition { 97 | return -1 98 | } 99 | if a.mainlinePosition > b.mainlinePosition { 100 | return 1 101 | } 102 | // If we've reached here then s[i].mainlinePosition == s[j].mainlinePosition 103 | // so instead try to tiebreak on step count 104 | if a.mainlineSteps < b.mainlineSteps { 105 | return -1 106 | } 107 | if a.mainlineSteps > b.mainlineSteps { 108 | return 1 109 | } 110 | // If we've reached here then s[i].mainlineSteps == s[j].mainlineSteps 111 | // so instead try to tiebreak on origin server TS 112 | if a.originServerTS < b.originServerTS { 113 | return -1 114 | } 115 | if a.originServerTS > b.originServerTS { 116 | return 1 117 | } 118 | // If we've reached here then s[i].originServerTS == s[j].originServerTS 119 | // so instead try to tiebreak on a lexicographical comparison of the event ID 120 | return strings.Compare(a.eventID, b.eventID) 121 | } 122 | 123 | // Push implements heap.Interface 124 | func (s *stateResV2ConflictedOtherHeap) Push(x *stateResV2ConflictedOther) { 125 | *s = append(*s, x) 126 | } 127 | 128 | // Pop implements heap.Interface 129 | func (s *stateResV2ConflictedOtherHeap) Pop() *stateResV2ConflictedOther { 130 | old := *s 131 | n := len(old) 132 | x := old[n-1] 133 | *s = old[:n-1] 134 | return x 135 | } 136 | -------------------------------------------------------------------------------- 
/fclient/federationtypes_test.go: -------------------------------------------------------------------------------- 1 | package fclient 2 | 3 | import ( 4 | "encoding/json" 5 | "strings" 6 | "testing" 7 | "unicode" 8 | 9 | "github.com/google/go-cmp/cmp" 10 | "github.com/matrix-org/gomatrixserverlib/spec" 11 | ) 12 | 13 | const emptyRespStateResponse = `{"pdus":[],"auth_chain":[]}` 14 | const emptyRespSendJoinResponse = `{"state":[],"auth_chain":[],"origin":""}` 15 | 16 | func TestParseServerName(t *testing.T) { 17 | validTests := map[string][]interface{}{ 18 | "www.example.org:1234": {"www.example.org", 1234}, 19 | "www.example.org": {"www.example.org", -1}, 20 | "1234.example.com": {"1234.example.com", -1}, 21 | "1.1.1.1:1234": {"1.1.1.1", 1234}, 22 | "1.1.1.1": {"1.1.1.1", -1}, 23 | "[1fff:0:a88:85a3::ac1f]:1234": {"[1fff:0:a88:85a3::ac1f]", 1234}, 24 | "[2001:0db8::ff00:0042]": {"[2001:0db8::ff00:0042]", -1}, 25 | } 26 | 27 | for input, output := range validTests { 28 | host, port, isValid := spec.ParseAndValidateServerName(spec.ServerName(input)) 29 | if !isValid { 30 | t.Errorf("Expected serverName '%s' to be parsed as valid, but was not", input) 31 | } 32 | 33 | if host != output[0] || port != output[1].(int) { 34 | t.Errorf( 35 | "Expected serverName '%s' to be cleaned and validated to '%s', %d, got '%s', %d", 36 | input, output[0], output[1], host, port, 37 | ) 38 | } 39 | } 40 | 41 | invalidTests := []string{ 42 | // ipv6 not in square brackets 43 | "2001:0db8::ff00:0042", 44 | 45 | // host with invalid characters 46 | "test_test.com", 47 | 48 | // ipv6 with insufficient parts 49 | "[2001:0db8:0000:0000:0000:ff00:0042]", 50 | 51 | // empty host 52 | ":8080", 53 | 54 | // empty string 55 | "", 56 | } 57 | 58 | for _, input := range invalidTests { 59 | _, _, isValid := spec.ParseAndValidateServerName(spec.ServerName(input)) 60 | if isValid { 61 | t.Errorf("Expected serverName '%s' to be rejected but was accepted", input) 62 | } 63 | } 64 | } 65 | 66 | func 
TestRespStateMarshalJSON(t *testing.T) { 67 | inputData := `{"pdus":[],"auth_chain":[]}` 68 | var input RespState 69 | if err := json.Unmarshal([]byte(inputData), &input); err != nil { 70 | t.Fatal(err) 71 | } 72 | 73 | gotBytes, err := json.Marshal(input) 74 | if err != nil { 75 | t.Fatal(err) 76 | } 77 | 78 | got := string(gotBytes) 79 | 80 | if emptyRespStateResponse != got { 81 | t.Errorf("json.Marshal(RespState(%q)): wanted %q, got %q", inputData, emptyRespStateResponse, got) 82 | } 83 | } 84 | 85 | func TestRespStateUnmarshalJSON(t *testing.T) { 86 | inputData := `{"pdus":[],"auth_chain":[]}` 87 | var input RespState 88 | if err := json.Unmarshal([]byte(inputData), &input); err != nil { 89 | t.Fatal(err) 90 | } 91 | 92 | gotBytes, err := json.Marshal(input) 93 | if err != nil { 94 | t.Fatal(err) 95 | } 96 | got := string(gotBytes) 97 | 98 | if emptyRespStateResponse != got { 99 | t.Errorf("json.Marshal(RespSendJoin(%q)): wanted %q, got %q", inputData, emptyRespStateResponse, got) 100 | } 101 | } 102 | 103 | func TestRespSendJoinMarshalJSON(t *testing.T) { 104 | // we unmarshall and marshall an empty send-join response, and check it round-trips correctly. 
105 | inputData := `{"state":[],"auth_chain":[],"origin":""}` 106 | var input RespSendJoin 107 | if err := json.Unmarshal([]byte(inputData), &input); err != nil { 108 | t.Fatal(err) 109 | } 110 | 111 | want := RespSendJoin{ 112 | StateEvents: []spec.RawJSON{}, 113 | AuthEvents: []spec.RawJSON{}, 114 | Origin: "", 115 | } 116 | if !cmp.Equal(input, want, cmp.AllowUnexported(RespSendJoin{})) { 117 | t.Errorf("json.Unmarshal(%s): wanted %+v, got %+v", inputData, want, input) 118 | } 119 | 120 | gotBytes, err := json.Marshal(input) 121 | if err != nil { 122 | t.Fatal(err) 123 | } 124 | got := string(gotBytes) 125 | if emptyRespSendJoinResponse != got { 126 | t.Errorf("json.Marshal(%+v): wanted '%s', got '%s'", input, emptyRespSendJoinResponse, got) 127 | } 128 | } 129 | 130 | func TestRespSendJoinMarshalJSONPartialState(t *testing.T) { 131 | inputData := `{ 132 | "state":[],"auth_chain":[],"origin":"o1", 133 | "members_omitted":true, 134 | "servers_in_room":["s1", "s2"] 135 | }` 136 | 137 | var input RespSendJoin 138 | if err := json.Unmarshal([]byte(inputData), &input); err != nil { 139 | t.Fatal(err) 140 | } 141 | 142 | want := RespSendJoin{ 143 | StateEvents: []spec.RawJSON{}, 144 | AuthEvents: []spec.RawJSON{}, 145 | Origin: "o1", 146 | MembersOmitted: true, 147 | ServersInRoom: []string{"s1", "s2"}, 148 | } 149 | if !cmp.Equal(input, want, cmp.AllowUnexported(RespSendJoin{})) { 150 | t.Errorf("json.Unmarshal(%s): wanted %+v, got %+v", inputData, want, input) 151 | } 152 | 153 | gotBytes, err := json.Marshal(input) 154 | if err != nil { 155 | t.Fatal(err) 156 | } 157 | got := string(gotBytes) 158 | // the result should be the input, with spaces removed 159 | wantJSON := strings.Map(func(r rune) rune { 160 | if unicode.IsSpace(r) { 161 | return -1 162 | } 163 | return r 164 | }, inputData) 165 | if wantJSON != got { 166 | t.Errorf("json.Marshal(%+v):\n wanted: '%s'\n got: '%s'", input, wantJSON, got) 167 | } 168 | } 169 | 
-------------------------------------------------------------------------------- /eventV2_test.go: -------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | "time" 8 | 9 | "github.com/matrix-org/gomatrixserverlib/spec" 10 | "github.com/stretchr/testify/assert" 11 | "golang.org/x/crypto/ed25519" 12 | ) 13 | 14 | func TestCheckFields(t *testing.T) { 15 | roomID := "!room:localhost" 16 | senderID := "@sender:localhost" 17 | tooLargeStateKey := strings.Repeat("ä", 150) 18 | tooLongStateKey := strings.Repeat("b", 256) 19 | 20 | tests := []struct { 21 | name string 22 | input ProtoEvent 23 | wantErr assert.ErrorAssertionFunc 24 | wantPersistable bool 25 | }{ 26 | 27 | { 28 | name: "fail due to invalid roomID", 29 | input: ProtoEvent{ 30 | SenderID: senderID, 31 | RoomID: "@invalid:room", 32 | PrevEvents: []string{}, 33 | AuthEvents: []string{}, 34 | Content: spec.RawJSON("{}"), 35 | Unsigned: spec.RawJSON("{}"), 36 | }, 37 | wantErr: assert.Error, 38 | }, 39 | { 40 | name: "fail due to event size", 41 | input: ProtoEvent{ 42 | SenderID: senderID, 43 | RoomID: roomID, 44 | PrevEvents: []string{}, 45 | AuthEvents: []string{}, 46 | Content: spec.RawJSON(fmt.Sprintf(`{"data":"%s"}`, strings.Repeat("x", maxEventLength))), 47 | Unsigned: spec.RawJSON("{}"), 48 | }, 49 | wantErr: assert.Error, 50 | }, 51 | { 52 | name: "fail due to senderID too long", 53 | input: ProtoEvent{ 54 | SenderID: fmt.Sprintf("@%s:localhost", strings.Repeat("a", 255)), 55 | RoomID: roomID, 56 | PrevEvents: []string{}, 57 | AuthEvents: []string{}, 58 | Content: spec.RawJSON("{}"), 59 | Unsigned: spec.RawJSON("{}"), 60 | }, 61 | wantErr: assert.Error, 62 | }, 63 | { 64 | name: "successfully check fields", 65 | input: ProtoEvent{ 66 | SenderID: senderID, 67 | RoomID: roomID, 68 | PrevEvents: []string{}, 69 | AuthEvents: []string{}, 70 | Content: spec.RawJSON("{}"), 71 | Unsigned: 
spec.RawJSON("{}"), 72 | }, 73 | wantErr: assert.NoError, 74 | }, { 75 | name: "fail due to senderID too large", 76 | input: ProtoEvent{ 77 | SenderID: fmt.Sprintf("@%s:localhost", strings.Repeat("ä", 200)), 78 | RoomID: roomID, 79 | PrevEvents: []string{}, 80 | AuthEvents: []string{}, 81 | Content: spec.RawJSON("{}"), 82 | Unsigned: spec.RawJSON("{}"), 83 | }, 84 | wantErr: assert.Error, 85 | wantPersistable: true, 86 | }, 87 | { 88 | name: "fail due to type too large", 89 | input: ProtoEvent{ 90 | SenderID: fmt.Sprintf("@%s:localhost", strings.Repeat("ä", 10)), 91 | Type: strings.Repeat("ä", 150), 92 | RoomID: roomID, 93 | PrevEvents: []string{}, 94 | AuthEvents: []string{}, 95 | Content: spec.RawJSON("{}"), 96 | Unsigned: spec.RawJSON("{}"), 97 | }, 98 | wantErr: assert.Error, 99 | wantPersistable: true, 100 | }, 101 | { 102 | name: "fail due to type too long", 103 | input: ProtoEvent{ 104 | SenderID: fmt.Sprintf("@%s:localhost", strings.Repeat("ä", 10)), 105 | Type: strings.Repeat("b", 256), 106 | RoomID: roomID, 107 | PrevEvents: []string{}, 108 | AuthEvents: []string{}, 109 | Content: spec.RawJSON("{}"), 110 | Unsigned: spec.RawJSON("{}"), 111 | }, 112 | wantErr: assert.Error, 113 | wantPersistable: false, 114 | }, 115 | { 116 | name: "fail due to state_key too large", 117 | input: ProtoEvent{ 118 | SenderID: fmt.Sprintf("@%s:localhost", strings.Repeat("ä", 10)), 119 | StateKey: &tooLargeStateKey, 120 | RoomID: roomID, 121 | PrevEvents: []string{}, 122 | AuthEvents: []string{}, 123 | Content: spec.RawJSON("{}"), 124 | Unsigned: spec.RawJSON("{}"), 125 | }, 126 | wantErr: assert.Error, 127 | wantPersistable: true, 128 | }, 129 | { 130 | name: "fail due to state_key too long", 131 | input: ProtoEvent{ 132 | SenderID: fmt.Sprintf("@%s:localhost", strings.Repeat("ä", 10)), 133 | StateKey: &tooLongStateKey, 134 | RoomID: roomID, 135 | PrevEvents: []string{}, 136 | AuthEvents: []string{}, 137 | Content: spec.RawJSON("{}"), 138 | Unsigned: spec.RawJSON("{}"), 139 | 
}, 140 | wantErr: assert.Error, 141 | wantPersistable: false, 142 | }, 143 | } 144 | _, sk, err := ed25519.GenerateKey(nil) 145 | assert.NoError(t, err) 146 | for _, tt := range tests { 147 | t.Run(tt.name, func(t *testing.T) { 148 | for roomVersion := range roomVersionMeta { 149 | if roomVersion == RoomVersionPseudoIDs { 150 | continue 151 | } 152 | t.Run(tt.name+"-"+string(roomVersion), func(t *testing.T) { 153 | ev, err := MustGetRoomVersion(roomVersion).NewEventBuilderFromProtoEvent(&tt.input).Build(time.Now(), "localhost", "ed25519:1", sk) 154 | tt.wantErr(t, err) 155 | if ev != nil { 156 | err = CheckFields(ev) 157 | tt.wantErr(t, err, fmt.Sprintf("CheckFields(%v)", tt.input)) 158 | t.Logf("%v", err) 159 | } 160 | switch e := err.(type) { 161 | case EventValidationError: 162 | assert.Equalf(t, tt.wantPersistable, e.Persistable, "unexpected persistable") 163 | } 164 | }) 165 | 166 | } 167 | 168 | }) 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /signing.go: -------------------------------------------------------------------------------- 1 | /* Copyright 2016-2017 Vector Creations Ltd 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package gomatrixserverlib 17 | 18 | import ( 19 | "encoding/json" 20 | "fmt" 21 | 22 | "github.com/matrix-org/gomatrixserverlib/spec" 23 | "github.com/tidwall/sjson" 24 | "golang.org/x/crypto/ed25519" 25 | ) 26 | 27 | // A KeyID is the ID of a ed25519 key used to sign JSON. 28 | // The key IDs have a format of "ed25519:[0-9A-Za-z]+" 29 | // If we switch to using a different signing algorithm then we will change the 30 | // prefix used. 31 | type KeyID string 32 | 33 | // SignJSON signs a JSON object returning a copy signed with the given key. 34 | // https://matrix.org/docs/spec/server_server/unstable.html#signing-json 35 | func SignJSON(signingName string, keyID KeyID, privateKey ed25519.PrivateKey, message []byte) (signed []byte, err error) { 36 | preserve := struct { 37 | Signatures map[string]map[KeyID]spec.Base64Bytes `json:"signatures"` 38 | Unsigned spec.RawJSON `json:"unsigned"` 39 | }{ 40 | Signatures: map[string]map[KeyID]spec.Base64Bytes{}, 41 | } 42 | if err = json.Unmarshal(message, &preserve); err != nil { 43 | return nil, err 44 | } 45 | if message, err = sjson.DeleteBytes(message, "signatures"); err != nil { 46 | return nil, err 47 | } 48 | if message, err = sjson.DeleteBytes(message, "unsigned"); err != nil { 49 | return nil, err 50 | } 51 | canonical, err := CanonicalJSON(message) 52 | if err != nil { 53 | return nil, err 54 | } 55 | signature := spec.Base64Bytes(ed25519.Sign(privateKey, canonical)) 56 | if _, ok := preserve.Signatures[signingName]; ok { 57 | preserve.Signatures[signingName][keyID] = signature 58 | } else { 59 | preserve.Signatures[signingName] = map[KeyID]spec.Base64Bytes{ 60 | keyID: signature, 61 | } 62 | } 63 | signatures, err := json.Marshal(preserve.Signatures) 64 | if err != nil { 65 | return nil, err 66 | } 67 | if signed, err = sjson.SetRawBytes(canonical, "signatures", signatures); err != nil { 68 | return nil, err 69 | } 70 | if len(preserve.Unsigned) > 0 { 71 | if signed, err = 
sjson.SetRawBytes(signed, "unsigned", preserve.Unsigned); err != nil { 72 | return nil, err 73 | } 74 | } 75 | if signed, err = CanonicalJSON(signed); err != nil { 76 | return nil, err 77 | } 78 | return 79 | } 80 | 81 | // ListKeyIDs lists the key IDs a given entity has signed a message with. 82 | func ListKeyIDs(signingName string, message []byte) ([]KeyID, error) { 83 | var object struct { 84 | Signatures map[string]map[KeyID]json.RawMessage `json:"signatures"` 85 | } 86 | if err := json.Unmarshal(message, &object); err != nil { 87 | return nil, err 88 | } 89 | var result []KeyID 90 | for keyID := range object.Signatures[signingName] { 91 | result = append(result, keyID) 92 | } 93 | return result, nil 94 | } 95 | 96 | // VerifyJSON checks that the entity has signed the message using a particular key. 97 | func VerifyJSON(signingName string, keyID KeyID, publicKey ed25519.PublicKey, message []byte) error { 98 | // Unpack the top-level key of the JSON object without unpacking the contents of the keys. 99 | // This allows us to add and remove the top-level keys from the JSON object. 100 | // It also ensures that the JSON is actually a valid JSON object. 101 | var object map[string]*json.RawMessage 102 | var signatures map[string]map[KeyID]spec.Base64Bytes 103 | if err := json.Unmarshal(message, &object); err != nil { 104 | return err 105 | } 106 | 107 | // Check that there is a signature from the entity that we are expecting a signature from. 
108 | if object["signatures"] == nil { 109 | return fmt.Errorf("No signatures") 110 | } 111 | if err := json.Unmarshal(*object["signatures"], &signatures); err != nil { 112 | return err 113 | } 114 | signature, ok := signatures[signingName][keyID] 115 | if !ok { 116 | return fmt.Errorf("No signature from %q with ID %q", signingName, keyID) 117 | } 118 | if len(signature) != ed25519.SignatureSize { 119 | return fmt.Errorf("Bad signature length from %q with ID %q", signingName, keyID) 120 | } 121 | 122 | // The "unsigned" key and "signatures" keys aren't covered by the signature so remove them. 123 | delete(object, "unsigned") 124 | delete(object, "signatures") 125 | 126 | // Encode the JSON without the "unsigned" and "signatures" keys in the canonical format. 127 | unsorted, err := json.Marshal(object) 128 | if err != nil { 129 | return err 130 | } 131 | canonical, err := CanonicalJSON(unsorted) 132 | if err != nil { 133 | return err 134 | } 135 | 136 | // Verify the ed25519 signature. 137 | if !ed25519.Verify(publicKey, canonical, signature) { 138 | return fmt.Errorf("Bad signature from %q with ID %q", signingName, keyID) 139 | } 140 | 141 | return nil 142 | } 143 | -------------------------------------------------------------------------------- /pdu.go: -------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/matrix-org/gomatrixserverlib/spec" 9 | "golang.org/x/crypto/ed25519" 10 | ) 11 | 12 | type PDU interface { 13 | EventID() string 14 | StateKey() *string 15 | StateKeyEquals(s string) bool 16 | Type() string 17 | Content() []byte 18 | // JoinRule returns the value of the content.join_rule field if this event 19 | // is an "m.room.join_rules" event. 20 | // Returns an error if the event is not a m.room.join_rules event or if the content 21 | // is not valid m.room.join_rules content. 
22 | JoinRule() (string, error) 23 | // HistoryVisibility returns the value of the content.history_visibility field if this event 24 | // is an "m.room.history_visibility" event. 25 | // Returns an error if the event is not a m.room.history_visibility event or if the content 26 | // is not valid m.room.history_visibility content. 27 | HistoryVisibility() (HistoryVisibility, error) 28 | Membership() (string, error) 29 | PowerLevels() (*PowerLevelContent, error) 30 | Version() RoomVersion 31 | RoomID() spec.RoomID 32 | Redacts() string 33 | // Redacted returns whether the event is redacted. 34 | Redacted() bool 35 | PrevEventIDs() []string 36 | OriginServerTS() spec.Timestamp 37 | // Redact redacts the event. 38 | Redact() 39 | SenderID() spec.SenderID 40 | Unsigned() []byte 41 | // SetUnsigned sets the unsigned key of the event. 42 | // Returns a copy of the event with the "unsigned" key set. 43 | SetUnsigned(unsigned interface{}) (PDU, error) 44 | // SetUnsignedField takes a path and value to insert into the unsigned dict of 45 | // the event. 46 | // path is a dot separated path into the unsigned dict (see gjson package 47 | // for details on format). In particular some characters like '.' and '*' must 48 | // be escaped. 49 | SetUnsignedField(path string, value interface{}) error 50 | // Sign returns a copy of the event with an additional signature. 51 | Sign(signingName string, keyID KeyID, privateKey ed25519.PrivateKey) PDU 52 | Depth() int64 // TODO: remove 53 | JSON() []byte // TODO: remove 54 | AuthEventIDs() []string // TODO: remove 55 | ToHeaderedJSON() ([]byte, error) // TODO: remove 56 | // IsSticky returns true if the event is *currently* considered "sticky" given the received time. 57 | // Sticky events are annotated as sticky and carry strong delivery guarantees to clients (and 58 | // therefore servers). `received` should be specified as the time the event was received by the 59 | // server if, and only if, the event was received over `/send`. 
Otherwise, `received` should be 60 | // `time.Now()`. Returns false if the event is not sticky, or no longer sticky. `now` can be supplied 61 | // to override the current time. 62 | IsSticky(now time.Time, received time.Time) bool 63 | // StickyEndTime returns the time at which the event is no longer considered "sticky". See `IsSticky` 64 | // for details on sticky events. Returns `time.Time{}` (zero) if the event is not a sticky event. 65 | StickyEndTime(received time.Time) time.Time 66 | } 67 | 68 | // Convert a slice of concrete PDU implementations to a slice of PDUs. This is useful when 69 | // interfacing with GMSL functions which require []PDU. 70 | func ToPDUs[T PDU](events []T) []PDU { 71 | result := make([]PDU, len(events)) 72 | for i := range events { 73 | result[i] = events[i] 74 | } 75 | return result 76 | } 77 | 78 | // A StateKeyTuple is the combination of an event type and an event state key. 79 | // It is often used as a key in maps. 80 | type StateKeyTuple struct { 81 | // The "type" key of a matrix event. 82 | EventType string 83 | // The "state_key" of a matrix event. 84 | // The empty string is a legitimate value for the "state_key" in matrix 85 | // so take care to initialise this field lest you accidentally request a 86 | // "state_key" with the go default of the empty string. 87 | StateKey string 88 | } 89 | 90 | // An eventReference is a reference to a matrix event. 91 | type eventReference struct { 92 | // The event ID of the event. 93 | EventID string 94 | // The sha256 of the redacted event. 
95 | EventSHA256 spec.Base64Bytes 96 | } 97 | 98 | // UnmarshalJSON implements json.Unmarshaller 99 | func (er *eventReference) UnmarshalJSON(data []byte) error { 100 | var tuple []spec.RawJSON 101 | if err := json.Unmarshal(data, &tuple); err != nil { 102 | return err 103 | } 104 | if len(tuple) != 2 { 105 | return fmt.Errorf("gomatrixserverlib: invalid event reference, invalid length: %d != 2", len(tuple)) 106 | } 107 | if err := json.Unmarshal(tuple[0], &er.EventID); err != nil { 108 | return fmt.Errorf("gomatrixserverlib: invalid event reference, first element is invalid: %q %v", string(tuple[0]), err) 109 | } 110 | var hashes struct { 111 | SHA256 spec.Base64Bytes `json:"sha256"` 112 | } 113 | if err := json.Unmarshal(tuple[1], &hashes); err != nil { 114 | return fmt.Errorf("gomatrixserverlib: invalid event reference, second element is invalid: %q %v", string(tuple[1]), err) 115 | } 116 | er.EventSHA256 = hashes.SHA256 117 | return nil 118 | } 119 | 120 | // MarshalJSON implements json.Marshaller 121 | func (er eventReference) MarshalJSON() ([]byte, error) { 122 | hashes := struct { 123 | SHA256 spec.Base64Bytes `json:"sha256"` 124 | }{er.EventSHA256} 125 | 126 | tuple := []interface{}{er.EventID, hashes} 127 | 128 | return json.Marshal(&tuple) 129 | } 130 | -------------------------------------------------------------------------------- /eventV1_test.go: -------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "crypto/sha256" 5 | "encoding/base64" 6 | "encoding/json" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/tidwall/sjson" 12 | ) 13 | 14 | func makeStickyEvent(t *testing.T, durationMS int64, originTS int64, stateKey *string) PDU { 15 | verImpl := MustGetRoomVersion(RoomVersionV12) 16 | 17 | m := map[string]interface{}{ 18 | "sticky": map[string]int64{ 19 | "duration_ms": durationMS, 20 | }, 21 | "room_id": 
"!L6nFTAu28CEi9yn9up1SUiKtTNnKt2yomgy2JFRT2Zk", 22 | "type": "m.room.message", 23 | "sender": "@user:localhost", 24 | "content": map[string]interface{}{ 25 | "body": "Hello, World!", 26 | "msgtype": "m.text", 27 | }, 28 | "origin_server_ts": originTS, 29 | "unsigned": make(map[string]interface{}), 30 | "depth": 1, 31 | "origin": "localhost", 32 | "prev_events": []string{"$65vISquU7WNlFCaJeJ5uohlX4LVEPx5yEkAc1hpRf44"}, 33 | "auth_events": []string{"$65vISquU7WNlFCaJeJ5uohlX4LVEPx5yEkAc1hpRf44"}, 34 | "hashes": map[string]string{ 35 | "sha256": "1234567890", 36 | }, 37 | "signatures": map[string]interface{}{ 38 | "localhost": map[string]string{ 39 | "ed25519:localhost": "doesn't matter because it's not checked", 40 | }, 41 | }, 42 | } 43 | if stateKey != nil { 44 | m["state_key"] = *stateKey 45 | } 46 | if durationMS < 0 { 47 | delete(m, "sticky") 48 | } 49 | 50 | b, err := json.Marshal(m) 51 | assert.NoError(t, err, "failed to marshal sticky message event") 52 | 53 | // we need to add hashes manually so we don't cause our event to become redacted 54 | cj, err := CanonicalJSON(b) 55 | assert.NoError(t, err, "failed to canonicalize sticky message event") 56 | for _, key := range []string{"signatures", "unsigned", "hashes"} { 57 | cj, err = sjson.DeleteBytes(cj, key) 58 | assert.NoErrorf(t, err, "failed to delete %s from sticky message event", key) 59 | } 60 | sum := sha256.Sum256(cj) 61 | b, err = sjson.SetBytes(b, "hashes.sha256", base64.RawURLEncoding.EncodeToString(sum[:])) 62 | assert.NoError(t, err, "failed to set sha256 hash on sticky message event") 63 | 64 | ev, err := verImpl.NewEventFromUntrustedJSON(b) 65 | assert.NoError(t, err, "failed to create new untrusted sticky message event") 66 | assert.NotNil(t, ev) 67 | return ev 68 | } 69 | 70 | func TestIsSticky(t *testing.T) { 71 | now := time.Now() 72 | 73 | // Happy path 74 | ev := makeStickyEvent(t, 20000, now.UnixMilli(), nil) 75 | assert.True(t, ev.IsSticky(now, now)) 76 | 77 | // Origin before now 78 | 
ev = makeStickyEvent(t, 20000, now.UnixMilli()-10000, nil) 79 | assert.True(t, ev.IsSticky(now, now)) // should use the -10s time from origin as the start time 80 | 81 | // Origin in the future 82 | ev = makeStickyEvent(t, 20000, now.UnixMilli()+30000, nil) 83 | assert.True(t, ev.IsSticky(now, now)) // This will switch to using Now() instead of the 30s future, so should be in range 84 | 85 | // Origin is well before now, leading to expiration upon receipt 86 | ev = makeStickyEvent(t, 20000, now.UnixMilli()-30000, nil) 87 | assert.False(t, ev.IsSticky(now, now)) 88 | 89 | // State events can also be sticky 90 | stateKey := "state_key" 91 | ev = makeStickyEvent(t, 20000, now.UnixMilli(), &stateKey) 92 | assert.True(t, ev.IsSticky(now, now)) 93 | 94 | // Not a sticky event 95 | ev = makeStickyEvent(t, -1, now.UnixMilli(), nil) // -1 creates a non-sticky event 96 | assert.False(t, ev.IsSticky(now, now)) 97 | } 98 | 99 | func TestStickyEndTime(t *testing.T) { 100 | now := time.Now().UTC().Truncate(time.Millisecond) 101 | nowTS := now.UnixMilli() 102 | received := now 103 | 104 | // Happy path: event is a message event, and origin and duration are within range 105 | ev := makeStickyEvent(t, 20000, nowTS, nil) 106 | assert.Equal(t, now.Add(20*time.Second), ev.StickyEndTime(received)) 107 | 108 | // Origin before now, but duration still within range 109 | ev = makeStickyEvent(t, 20000, nowTS-10000, nil) 110 | assert.Equal(t, now.Add(10*time.Second), ev.StickyEndTime(received)) // +10 s because origin is -10s with a duration of 20s 111 | 112 | // Origin and duration before now 113 | ev = makeStickyEvent(t, 20000, nowTS-30000, nil) 114 | assert.Equal(t, received.Add(-10*time.Second), ev.StickyEndTime(received)) // 10s before received (-30+20 = -10) 115 | 116 | // Origin in the future (using received time instead), duration still within range 117 | ev = makeStickyEvent(t, 20000, nowTS+10000, nil) 118 | assert.Equal(t, now.Add(20*time.Second), ev.StickyEndTime(received)) // 
+20s because we'll use the received time as a start time 119 | 120 | // Origin is in the future, which places the start time before the origin 121 | ev = makeStickyEvent(t, 20000, nowTS+30000, nil) 122 | assert.Equal(t, received.Add(20*time.Second), ev.StickyEndTime(received)) // The origin is ignored, so +20s for the duration 123 | 124 | // Duration is more than an hour 125 | ev = makeStickyEvent(t, 3699999, nowTS, nil) 126 | assert.Equal(t, now.Add(1*time.Hour), ev.StickyEndTime(received)) 127 | 128 | // State events can also be sticky 129 | stateKey := "state_key" 130 | ev = makeStickyEvent(t, 20000, nowTS, &stateKey) 131 | assert.Equal(t, now.Add(20*time.Second), ev.StickyEndTime(received)) 132 | 133 | // Not a sticky event 134 | ev = makeStickyEvent(t, -1, nowTS, nil) // -1 creates a non-sticky event 135 | assert.Equal(t, time.Time{}, ev.StickyEndTime(received)) 136 | } 137 | -------------------------------------------------------------------------------- /load.go: -------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "strings" 8 | 9 | "github.com/matrix-org/gomatrixserverlib/spec" 10 | ) 11 | 12 | // EventLoadResult is the result of loading and verifying an event in the EventsLoader. 13 | type EventLoadResult struct { 14 | Event PDU 15 | Error error 16 | SoftFail bool 17 | } 18 | 19 | // EventsLoader loads untrusted events and verifies them. 20 | type EventsLoader struct { 21 | roomVer RoomVersion 22 | keyRing JSONVerifier 23 | provider EventProvider 24 | stateProvider StateProvider 25 | // Set to true to do: 26 | // 6. Passes authorization rules based on the current state of the room, otherwise it is "soft failed". 27 | // This is only desirable for live events, not backfilled events hence the flag. 
28 | performSoftFailCheck bool 29 | } 30 | 31 | // NewEventsLoader returns a new events loader 32 | func NewEventsLoader(roomVer RoomVersion, keyRing JSONVerifier, stateProvider StateProvider, provider EventProvider, performSoftFailCheck bool) *EventsLoader { 33 | return &EventsLoader{ 34 | roomVer: roomVer, 35 | keyRing: keyRing, 36 | provider: provider, 37 | stateProvider: stateProvider, 38 | performSoftFailCheck: performSoftFailCheck, 39 | } 40 | } 41 | 42 | // LoadAndVerify loads untrusted events and verifies them. 43 | // Checks performed are outlined at https://matrix.org/docs/spec/server_server/latest#checks-performed-on-receipt-of-a-pdu 44 | // The length of the returned slice will always equal the length of rawEvents. 45 | // The order of the returned events depends on `sortOrder`. The events are reverse topologically sorted by the ordering specified. However, 46 | // in order to sort the events must be loaded which could fail. For those events which fail to be loaded, they will 47 | // be put at the end of the returned slice. 48 | func (l *EventsLoader) LoadAndVerify(ctx context.Context, rawEvents []json.RawMessage, sortOrder TopologicalOrder, userIDForSender spec.UserIDForSender) ([]EventLoadResult, error) { 49 | results := make([]EventLoadResult, len(rawEvents)) 50 | 51 | verImpl, err := GetRoomVersion(l.roomVer) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | // 1. Is a valid event, otherwise it is dropped. 57 | // 3. Passes hash checks, otherwise it is redacted before being processed further. 
58 | events := make([]PDU, 0, len(rawEvents)) 59 | errs := make([]error, 0, len(rawEvents)) 60 | for _, rawEv := range rawEvents { 61 | event, err := verImpl.NewEventFromUntrustedJSON(rawEv) 62 | if err != nil { 63 | errs = append(errs, err) 64 | continue 65 | } 66 | events = append(events, event) 67 | } 68 | 69 | events = ReverseTopologicalOrdering(events, sortOrder) 70 | // assign the errors to the end of the slice 71 | for i := 0; i < len(errs); i++ { 72 | results[len(results)-len(errs)+i] = EventLoadResult{ 73 | Error: errs[i], 74 | } 75 | } 76 | // at this point, the three slices look something like: 77 | // results: [ _ , _ , _ , err1 , err2 ] 78 | // errs: [ err1, err2 ] 79 | // events [ ev1, ev2, ev3 ] 80 | // so we can directly index from events into results from now on. 81 | 82 | // 2. Passes signature checks, otherwise it is dropped. 83 | failures := VerifyAllEventSignatures(ctx, events, l.keyRing, userIDForSender) 84 | if len(failures) != len(events) { 85 | return nil, fmt.Errorf("gomatrixserverlib: bulk event signature verification length mismatch: %d != %d", len(failures), len(events)) 86 | } 87 | for i := range events { 88 | h := events[i] 89 | results[i] = EventLoadResult{ 90 | Event: h, 91 | } 92 | if eventErr := failures[i]; eventErr != nil { 93 | if results[i].Error == nil { // could have failed earlier 94 | results[i].Error = SignatureErr{eventErr} 95 | continue 96 | } 97 | } 98 | // 4. Passes authorization rules based on the event's auth events, otherwise it is rejected. 99 | if err := VerifyEventAuthChain(ctx, h, l.provider, userIDForSender); err != nil { 100 | if results[i].Error == nil { // could have failed earlier 101 | results[i].Error = AuthChainErr{err} 102 | continue 103 | } 104 | } 105 | 106 | // 5. Passes authorization rules based on the state at the event, otherwise it is rejected. 
107 | if err := VerifyAuthRulesAtState(ctx, l.stateProvider, h, true, userIDForSender); err != nil { 108 | if results[i].Error == nil { // could have failed earlier 109 | results[i].Error = AuthRulesErr{err} 110 | continue 111 | } 112 | } 113 | } 114 | 115 | // TODO: performSoftFailCheck, needs forward extremity 116 | return results, nil 117 | } 118 | 119 | type SignatureErr struct { 120 | err error 121 | } 122 | 123 | func (se SignatureErr) Error() string { 124 | return fmt.Sprintf("SignatureErr: %s", se.err) 125 | } 126 | 127 | func (se SignatureErr) Is(target error) bool { 128 | return strings.HasPrefix(target.Error(), "SignatureErr") 129 | } 130 | 131 | type AuthChainErr struct { 132 | err error 133 | } 134 | 135 | func (se AuthChainErr) Error() string { 136 | return fmt.Sprintf("AuthChainErr: %s", se.err) 137 | } 138 | 139 | func (se AuthChainErr) Is(target error) bool { 140 | return strings.HasPrefix(target.Error(), "AuthChainErr") 141 | } 142 | 143 | type AuthRulesErr struct { 144 | err error 145 | } 146 | 147 | func (se AuthRulesErr) Error() string { 148 | return fmt.Sprintf("AuthRulesErr: %s", se.err) 149 | } 150 | 151 | func (se AuthRulesErr) Is(target error) bool { 152 | return strings.HasPrefix(target.Error(), "AuthRulesErr") 153 | } 154 | -------------------------------------------------------------------------------- /spec/matrixerror_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Vector Creations Ltd 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package spec 16 | 17 | import ( 18 | "encoding/json" 19 | "testing" 20 | 21 | "github.com/stretchr/testify/assert" 22 | ) 23 | 24 | func TestSimpleMatrixErrors(t *testing.T) { 25 | tests := map[string]struct { 26 | errorString string 27 | customErrMsg string 28 | errorFunc func(string) MatrixError 29 | }{ 30 | "m_unknown": {errorString: "M_UNKNOWN", errorFunc: Unknown}, 31 | "m_unrecognized": {errorString: "M_UNRECOGNIZED", errorFunc: Unrecognized}, 32 | "m_forbidden": {errorString: "M_FORBIDDEN", errorFunc: Forbidden}, 33 | "m_bad_json": {errorString: "M_BAD_JSON", errorFunc: BadJSON}, 34 | "m_bad_alias": {errorString: "M_BAD_ALIAS", errorFunc: BadAlias}, 35 | "m_not_json": {errorString: "M_NOT_JSON", errorFunc: NotJSON}, 36 | "m_not_found": {errorString: "M_NOT_FOUND", errorFunc: NotFound}, 37 | "m_missing_token": {errorString: "M_MISSING_TOKEN", errorFunc: MissingToken}, 38 | "m_unknown_token": {errorString: "M_UNKNOWN_TOKEN", errorFunc: UnknownToken}, 39 | "m_weak_password": {errorString: "M_WEAK_PASSWORD", errorFunc: WeakPassword}, 40 | "m_invalid_username": {errorString: "M_INVALID_USERNAME", errorFunc: InvalidUsername}, 41 | "m_user_in_use": {errorString: "M_USER_IN_USE", errorFunc: UserInUse}, 42 | "m_room_in_use": {errorString: "M_ROOM_IN_USE", errorFunc: RoomInUse}, 43 | "m_exclusive": {errorString: "M_EXCLUSIVE", errorFunc: ASExclusive}, 44 | "m_guest_access_forbidden": {errorString: "M_GUEST_ACCESS_FORBIDDEN", errorFunc: GuestAccessForbidden}, 45 | "m_invalid_signature": {errorString: "M_INVALID_SIGNATURE", errorFunc: InvalidSignature}, 46 | "m_invalid_param": {errorString: "M_INVALID_PARAM", errorFunc: InvalidParam}, 47 | "m_missing_param": {errorString: "M_MISSING_PARAM", errorFunc: MissingParam}, 48 | "m_unable_to_authorise_join": {errorString: "M_UNABLE_TO_AUTHORISE_JOIN", errorFunc: UnableToAuthoriseJoin}, 49 | 
"m_unsupported_room_version": {errorString: "M_UNSUPPORTED_ROOM_VERSION", errorFunc: UnsupportedRoomVersion}, 50 | "m_server_not_trusted": {errorString: "M_SERVER_NOT_TRUSTED", errorFunc: NotTrusted, customErrMsg: "Untrusted server 'error msg'"}, 51 | } 52 | 53 | for name, tc := range tests { 54 | t.Run(name, func(t *testing.T) { 55 | errorMsg := "error msg" 56 | e := tc.errorFunc(errorMsg) 57 | jsonBytes, err := json.Marshal(&e) 58 | if err != nil { 59 | t.Fatalf("Failed to marshal error. %s", err.Error()) 60 | } 61 | if tc.customErrMsg != "" { 62 | errorMsg = tc.customErrMsg 63 | } 64 | want := `{"errcode":"` + tc.errorString + `","error":"` + errorMsg + `"}` 65 | if string(jsonBytes) != want { 66 | t.Errorf("want %s, got %s", want, string(jsonBytes)) 67 | } 68 | }) 69 | } 70 | } 71 | 72 | func TestInternalServerError(t *testing.T) { 73 | e := InternalServerError{} 74 | assert.NotPanics(t, func() { _ = e.Error() }) 75 | } 76 | 77 | func TestLimitExceeded(t *testing.T) { 78 | e := LimitExceeded("error msg", 500) 79 | jsonBytes, err := json.Marshal(&e) 80 | if err != nil { 81 | t.Fatalf("Failed to marshal error. %s", err.Error()) 82 | } 83 | want := `{"errcode":"M_LIMIT_EXCEEDED","error":"error msg","retry_after_ms":500}` 84 | if string(jsonBytes) != want { 85 | t.Errorf("want %s, got %s", want, string(jsonBytes)) 86 | } 87 | } 88 | 89 | func TestLeaveServerNoticeError(t *testing.T) { 90 | e := LeaveServerNoticeError() 91 | jsonBytes, err := json.Marshal(&e) 92 | if err != nil { 93 | t.Fatalf("Failed to marshal error. %s", err.Error()) 94 | } 95 | want := `{"errcode":"M_CANNOT_LEAVE_SERVER_NOTICE_ROOM","error":"You cannot reject this invite"}` 96 | if string(jsonBytes) != want { 97 | t.Errorf("want %s, got %s", want, string(jsonBytes)) 98 | } 99 | } 100 | 101 | func TestWrongRoomKeysVersion(t *testing.T) { 102 | e := WrongBackupVersionError("error msg") 103 | jsonBytes, err := json.Marshal(&e) 104 | if err != nil { 105 | t.Fatalf("Failed to marshal error. 
%s", err.Error()) 106 | } 107 | want := `{"errcode":"M_WRONG_ROOM_KEYS_VERSION","error":"Wrong backup version.","current_version":"error msg"}` 108 | if string(jsonBytes) != want { 109 | t.Errorf("want %s, got %s", want, string(jsonBytes)) 110 | } 111 | } 112 | 113 | func TestIncompatibleRoomVersion(t *testing.T) { 114 | e := IncompatibleRoomVersion("error msg") 115 | jsonBytes, err := json.Marshal(&e) 116 | if err != nil { 117 | t.Fatalf("Failed to marshal error. %s", err.Error()) 118 | } 119 | want := `{"errcode":"M_INCOMPATIBLE_ROOM_VERSION","error":"Your homeserver does not support the features required to join this room","room_version":"error msg"}` 120 | if string(jsonBytes) != want { 121 | t.Errorf("want %s, got %s", want, string(jsonBytes)) 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /eventV3.go: -------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "strings" 8 | 9 | "github.com/matrix-org/gomatrixserverlib/spec" 10 | "github.com/tidwall/gjson" 11 | "github.com/tidwall/sjson" 12 | ) 13 | 14 | type eventV3 struct { 15 | eventV2 16 | } 17 | 18 | func (e *eventV3) RoomID() spec.RoomID { 19 | roomIDStr := e.eventFields.RoomID 20 | isCreateEvent := e.Type() == spec.MRoomCreate && e.StateKeyEquals("") 21 | if isCreateEvent { 22 | roomIDStr = fmt.Sprintf("!%s", e.EventID()[1:]) 23 | } 24 | roomID, err := spec.NewRoomID(roomIDStr) 25 | if err != nil { 26 | panic(fmt.Errorf("RoomID is invalid: %w", err)) 27 | } 28 | return *roomID 29 | } 30 | 31 | func (e *eventV3) AuthEventIDs() []string { 32 | isCreateEvent := e.Type() == spec.MRoomCreate && e.StateKeyEquals("") 33 | if isCreateEvent { 34 | return []string{} 35 | } 36 | createEventID := fmt.Sprintf("$%s", e.eventFields.RoomID[1:]) 37 | if len(e.AuthEvents) > 0 { 38 | // always include the create event 39 | return 
append([]string{createEventID}, e.AuthEvents...) 40 | } 41 | return []string{createEventID} 42 | } 43 | 44 | func newEventFromUntrustedJSONV3(eventJSON []byte, roomVersion IRoomVersion) (PDU, error) { 45 | if r := gjson.GetBytes(eventJSON, "_*"); r.Exists() { 46 | return nil, fmt.Errorf("gomatrixserverlib NewEventFromUntrustedJSON: found top-level '_' key, is this a headered event: %v", string(eventJSON)) 47 | } 48 | if err := roomVersion.CheckCanonicalJSON(eventJSON); err != nil { 49 | return nil, BadJSONError{err} 50 | } 51 | 52 | res := &eventV3{} 53 | var err error 54 | // Synapse removes these keys from events in case a server accidentally added them. 55 | // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/crypto/event_signing.py#L57-L62 56 | for _, key := range []string{"outlier", "destinations", "age_ts", "unsigned", "event_id"} { 57 | if eventJSON, err = sjson.DeleteBytes(eventJSON, key); err != nil { 58 | return nil, err 59 | } 60 | } 61 | 62 | if err = json.Unmarshal(eventJSON, &res); err != nil { 63 | return nil, err 64 | } 65 | 66 | // v3 events have room IDs as the create event ID. 67 | // TODO: allow validation to be enhanced/relaxed to help users like Complement. 68 | if err := checkRoomID(res); err != nil { 69 | return nil, err 70 | } 71 | 72 | res.roomVersion = roomVersion.Version() 73 | 74 | // We know the JSON must be valid here. 75 | eventJSON = CanonicalJSONAssumeValid(eventJSON) 76 | res.eventJSON = eventJSON 77 | 78 | if err = checkEventContentHash(eventJSON); err != nil { 79 | res.redacted = true 80 | 81 | // If the content hash doesn't match then we have to discard all non-essential fields 82 | // because they've been tampered with. 83 | var redactedJSON []byte 84 | if redactedJSON, err = roomVersion.RedactEventJSON(eventJSON); err != nil { 85 | return nil, err 86 | } 87 | 88 | redactedJSON = CanonicalJSONAssumeValid(redactedJSON) 89 | 90 | // We need to ensure that `result` is the redacted event. 
91 | // If redactedJSON is the same as eventJSON then `result` is already 92 | // correct. If not then we need to reparse. 93 | // 94 | // Yes, this means that for some events we parse twice (which is slow), 95 | // but means that parsing unredacted events is fast. 96 | if !bytes.Equal(redactedJSON, eventJSON) { 97 | result, err := roomVersion.NewEventFromTrustedJSON(redactedJSON, true) 98 | if err != nil { 99 | return nil, err 100 | } 101 | err = CheckFields(result) 102 | return result, err 103 | } 104 | } 105 | 106 | err = CheckFields(res) 107 | 108 | return res, err 109 | } 110 | 111 | func newEventFromTrustedJSONV3(eventJSON []byte, redacted bool, roomVersion IRoomVersion) (PDU, error) { 112 | res := eventV3{} 113 | if err := json.Unmarshal(eventJSON, &res); err != nil { 114 | return nil, err 115 | } 116 | 117 | // v3 events have room IDs as the create event ID. 118 | // TODO: allow validation to be enhanced/relaxed to help users like Complement. 119 | // TODO: feels weird to only have this validation here and not length checks etc :S 120 | if err := checkRoomID(&res); err != nil { 121 | return nil, err 122 | } 123 | 124 | res.roomVersion = roomVersion.Version() 125 | res.redacted = redacted 126 | res.eventJSON = eventJSON 127 | return &res, nil 128 | } 129 | 130 | func newEventFromTrustedJSONWithEventIDV3(eventID string, eventJSON []byte, redacted bool, roomVersion IRoomVersion) (PDU, error) { 131 | res := &eventV3{} 132 | if err := json.Unmarshal(eventJSON, &res); err != nil { 133 | return nil, err 134 | } 135 | 136 | // v3 events have room IDs as the create event ID. 137 | // TODO: allow validation to be enhanced/relaxed to help users like Complement. 
138 | if err := checkRoomID(res); err != nil { 139 | return nil, err 140 | } 141 | 142 | res.roomVersion = roomVersion.Version() 143 | res.eventJSON = eventJSON 144 | res.EventIDRaw = eventID 145 | res.redacted = redacted 146 | return res, nil 147 | } 148 | 149 | func checkRoomID(res *eventV3) error { 150 | isCreateEvent := res.Type() == spec.MRoomCreate && res.StateKeyEquals("") 151 | // TODO: We can't do this so long as we support partial Hydra impls in Complement 152 | // because otherwise if MSC4291=0 and MSC4289=1 then this check fails as the create 153 | // event will have a room_id. 154 | //if isCreateEvent && res.eventFields.RoomID != "" { 155 | //return fmt.Errorf("gomatrixserverlib: room_id must not exist on create event") 156 | //} 157 | if !isCreateEvent && !strings.HasPrefix(res.eventFields.RoomID, "!") { 158 | return fmt.Errorf("gomatrixserverlib: room_id must start with !") 159 | } 160 | return nil 161 | } 162 | -------------------------------------------------------------------------------- /backfill.go: -------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/matrix-org/gomatrixserverlib/spec" 8 | ) 9 | 10 | // BackfillClient contains the necessary functions from the federation client to perform a backfill request 11 | // from another homeserver. 12 | type BackfillClient interface { 13 | // Backfill performs a backfill request to the given server. 14 | // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid 15 | Backfill(ctx context.Context, origin, server spec.ServerName, roomID string, limit int, fromEventIDs []string) (Transaction, error) 16 | } 17 | 18 | // BackfillRequester contains the necessary functions to perform backfill requests from one server to another. 
19 | // 20 | // It requires a StateProvider in order to perform PDU checks on received events, notably the step 21 | // "Passes authorization rules based on the state at the event, otherwise it is rejected.". The BackfillRequester 22 | // will always call functions on the StateProvider in topological order, starting with the earliest event and 23 | // rolling forwards. This allows implementations to make optimisations for subsequent events, rather than 24 | // constantly deferring to federation requests. 25 | type BackfillRequester interface { 26 | StateProvider 27 | BackfillClient 28 | // ServersAtEvent is called when trying to determine which server to request from. 29 | // It returns a list of servers which can be queried for backfill requests. These servers 30 | // will be servers that are in the room already. The entries at the beginning are preferred servers 31 | // and will be tried first. An empty list will fail the request. 32 | ServersAtEvent(ctx context.Context, roomID, eventID string) []spec.ServerName 33 | ProvideEvents(roomVer RoomVersion, eventIDs []string) ([]PDU, error) 34 | } 35 | 36 | // RequestBackfill implements the server logic for making backfill requests to other servers. 37 | // This handles server selection, fetching up to the request limit and verifying the received events. 38 | // Event validation also includes authorisation checks, which may require additional state to be fetched. 39 | // 40 | // The returned events are safe to be inserted into a database for later retrieval. It's possible for the 41 | // number of returned events to be less than the limit, even if there exists more events. It's also possible 42 | // for the number of returned events to be greater than the limit, if fromEventIDs > 1 and we need to ask 43 | // multiple servers. We don't drop events greater than the limit because we've already done all the work to 44 | // verify them, so it's up to the caller to decide what to do with them. 
45 | // 46 | // TODO: We should be able to make some guarantees for the caller about the returned events position in the DAG, 47 | // but to verify it we need to know the prev_events of fromEventIDs. 48 | // 49 | // TODO: When does it make sense to return errors? 50 | func RequestBackfill(ctx context.Context, origin spec.ServerName, b BackfillRequester, keyRing JSONVerifier, 51 | roomID string, ver RoomVersion, fromEventIDs []string, limit int, userIDForSender spec.UserIDForSender) ([]PDU, error) { 52 | 53 | if len(fromEventIDs) == 0 { 54 | return nil, nil 55 | } 56 | haveEventIDs := make(map[string]bool) 57 | var result []PDU 58 | loader := NewEventsLoader(ver, keyRing, b, b.ProvideEvents, false) 59 | // pick a server to backfill from 60 | // TODO: use other event IDs and make a set out of all the returned servers? 61 | servers := b.ServersAtEvent(ctx, roomID, fromEventIDs[0]) 62 | // loop each server asking it for `limit` events. Worst case, we ask every server for `limit` 63 | // events before giving up. Best case, we just ask one. 
64 | var lastErr error 65 | for _, s := range servers { 66 | if len(result) >= limit { 67 | break 68 | } 69 | if ctx.Err() != nil { 70 | return nil, fmt.Errorf("gomatrixserverlib: RequestBackfill context cancelled %w", ctx.Err()) 71 | } 72 | // fetch some events, and try a different server if it fails 73 | txn, err := b.Backfill(ctx, origin, s, roomID, limit, fromEventIDs) 74 | if err != nil { 75 | lastErr = err 76 | continue // try the next server 77 | } 78 | // topologically sort the events so implementations of 'get state at event' can do optimisations 79 | loadResults, err := loader.LoadAndVerify(ctx, txn.PDUs, TopologicalOrderByPrevEvents, userIDForSender) 80 | if err != nil { 81 | lastErr = err 82 | continue // try the next server 83 | } 84 | for _, res := range loadResults { 85 | switch res.Error.(type) { 86 | case nil, SignatureErr: 87 | // The signature of the event might not be valid anymore, for example if 88 | // the key ID was reused with a different signature. 89 | case AuthChainErr, AuthRulesErr: 90 | continue 91 | default: 92 | continue 93 | } 94 | if haveEventIDs[res.Event.EventID()] { 95 | continue // we got this event from a different server 96 | } 97 | haveEventIDs[res.Event.EventID()] = true 98 | result = append(result, res.Event) 99 | } 100 | } 101 | 102 | // Since we pulled in results from multiple servers we need to sort again... 103 | return ReverseTopologicalOrdering(result, TopologicalOrderByPrevEvents), lastErr 104 | } 105 | 106 | /* 107 | // BackfillResponder contains the necessary functions to handle backfill requests. 108 | type backfillResponder interface { 109 | // TODO, unexported for now. 110 | } 111 | 112 | // ReceiveBackfill implements the server logic for processing backfill requests sent by a server. 113 | // This handles event selection via breadth-first search, as well as history visibility rules depending 114 | // on the state of the room at that point in time. 
115 | func receiveBackfill(b backfillResponder, roomID string, fromEventIDs []string, limit int) (*Transaction, error) { 116 | return nil, nil // TODO, unexported for now. 117 | } 118 | */ 119 | -------------------------------------------------------------------------------- /spec/base64_test.go: -------------------------------------------------------------------------------- 1 | /* Copyright 2016-2017 Vector Creations Ltd 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package spec 17 | 18 | import ( 19 | "bytes" 20 | "encoding/json" 21 | "testing" 22 | 23 | "gopkg.in/yaml.v2" 24 | ) 25 | 26 | func TestMarshalBase64(t *testing.T) { 27 | input := Base64Bytes("this\xffis\xffa\xfftest") 28 | want := `"dGhpc/9pc/9h/3Rlc3Q"` 29 | got, err := json.Marshal(input) 30 | if err != nil { 31 | t.Fatal(err) 32 | } 33 | if string(got) != want { 34 | t.Fatalf("json.Marshal(Base64Bytes(%q)): wanted %q got %q", string(input), want, string(got)) 35 | } 36 | } 37 | 38 | func TestUnmarshalBase64(t *testing.T) { 39 | input := []byte(`"dGhpc/9pc/9h/3Rlc3Q"`) 40 | want := "this\xffis\xffa\xfftest" // nolint:goconst 41 | var got Base64Bytes 42 | err := json.Unmarshal(input, &got) 43 | if err != nil { 44 | t.Fatal(err) 45 | } 46 | if string(got) != want { 47 | t.Fatalf("json.Unmarshal(%q): wanted %q got %q", string(input), want, string(got)) 48 | } 49 | } 50 | 51 | func TestUnmarshalUrlSafeBase64(t *testing.T) { 52 | input := []byte(`"dGhpc_9pc_9h_3Rlc3Q"`) 53 | want := "this\xffis\xffa\xfftest" 54 | var got Base64Bytes 55 | err := json.Unmarshal(input, &got) 56 | if err != nil { 57 | t.Fatal(err) 58 | } 59 | if string(got) != want { 60 | t.Fatalf("json.Unmarshal(%q): wanted %q got %q", string(input), want, string(got)) 61 | } 62 | } 63 | 64 | func TestMarshalBase64Struct(t *testing.T) { 65 | input := struct{ Value Base64Bytes }{Base64Bytes("this\xffis\xffa\xfftest")} 66 | want := `{"Value":"dGhpc/9pc/9h/3Rlc3Q"}` 67 | got, err := json.Marshal(input) 68 | if err != nil { 69 | t.Fatal(err) 70 | } 71 | if string(got) != want { 72 | t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got)) 73 | } 74 | } 75 | 76 | func TestMarshalBase64Map(t *testing.T) { 77 | input := map[string]Base64Bytes{"Value": Base64Bytes("this\xffis\xffa\xfftest")} 78 | want := `{"Value":"dGhpc/9pc/9h/3Rlc3Q"}` 79 | got, err := json.Marshal(input) 80 | if err != nil { 81 | t.Fatal(err) 82 | } 83 | if string(got) != want { 84 | 
t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got)) 85 | } 86 | } 87 | 88 | func TestMarshalBase64Slice(t *testing.T) { 89 | input := []Base64Bytes{Base64Bytes("this\xffis\xffa\xfftest")} 90 | want := `["dGhpc/9pc/9h/3Rlc3Q"]` 91 | got, err := json.Marshal(input) 92 | if err != nil { 93 | t.Fatal(err) 94 | } 95 | if string(got) != want { 96 | t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got)) 97 | } 98 | } 99 | 100 | func TestMarshalYAMLBase64(t *testing.T) { 101 | input := Base64Bytes("this\xffis\xffa\xfftest") 102 | want := "dGhpc/9pc/9h/3Rlc3Q\n" 103 | got, err := yaml.Marshal(input) 104 | if err != nil { 105 | t.Fatal(err) 106 | } 107 | if string(got) != want { 108 | t.Fatalf("yaml.Marshal(%v): wanted %q got %q", input, want, string(got)) 109 | } 110 | } 111 | 112 | func TestMarshalYAMLBase64Struct(t *testing.T) { 113 | input := struct{ Value Base64Bytes }{Base64Bytes("this\xffis\xffa\xfftest")} 114 | want := "value: dGhpc/9pc/9h/3Rlc3Q\n" 115 | got, err := yaml.Marshal(input) 116 | if err != nil { 117 | t.Fatal(err) 118 | } 119 | if string(got) != want { 120 | t.Fatalf("yaml.Marshal(%v): wanted %q got %q", input, want, string(got)) 121 | } 122 | } 123 | 124 | func TestUnmarshalYAMLBase64(t *testing.T) { 125 | input := []byte("dGhpc/9pc/9h/3Rlc3Q") 126 | want := Base64Bytes("this\xffis\xffa\xfftest") 127 | var got Base64Bytes 128 | err := yaml.Unmarshal(input, &got) 129 | if err != nil { 130 | t.Fatal(err) 131 | } 132 | if string(got) != string(want) { 133 | t.Fatalf("yaml.Unmarshal(%q): wanted %q got %q", string(input), want, string(got)) 134 | } 135 | } 136 | 137 | func TestUnmarshalYAMLBase64Struct(t *testing.T) { 138 | // var u yaml.Unmarshaler 139 | u := Base64Bytes("this\xffis\xffa\xfftest") 140 | 141 | input := []byte(`value: dGhpc/9pc/9h/3Rlc3Q`) 142 | want := struct{ Value Base64Bytes }{u} 143 | result := struct { 144 | Value Base64Bytes `yaml:"value"` 145 | }{} 146 | err := yaml.Unmarshal(input, &result) 147 
| if err != nil { 148 | t.Fatal(err) 149 | } 150 | if string(result.Value) != string(want.Value) { 151 | t.Fatalf("yaml.Unmarshal(%v): wanted %q got %q", input, want, result) 152 | } 153 | } 154 | 155 | func TestScanBase64(t *testing.T) { 156 | expecting := Base64Bytes("This is a test string") 157 | 158 | inputStr := "VGhpcyBpcyBhIHRlc3Qgc3RyaW5n" 159 | inputJSON := RawJSON(`"` + inputStr + `"`) 160 | inputBytes := []byte(expecting) 161 | inputInt := 3 162 | 163 | var b Base64Bytes 164 | 165 | if err := b.Scan(inputStr); err != nil { 166 | t.Fatal(err) 167 | } 168 | if !bytes.Equal(expecting, b) { 169 | t.Fatalf("scanning from string failed, got %v, wanted %v", b, expecting) 170 | } 171 | 172 | if err := b.Scan(inputJSON); err != nil { 173 | t.Fatal(err) 174 | } 175 | if !bytes.Equal(expecting, b) { 176 | t.Fatalf("scanning from RawJSON failed, got %v, wanted %v", b, expecting) 177 | } 178 | 179 | if err := b.Scan(inputBytes); err != nil { 180 | t.Fatal(err) 181 | } 182 | if !bytes.Equal(expecting, b) { 183 | t.Fatalf("scanning from []byte failed, got %v, wanted %v", b, expecting) 184 | } 185 | 186 | if err := b.Scan(inputInt); err == nil { 187 | t.Fatal("scanning from int should have failed but didn't") 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /fclient/resolve.go: -------------------------------------------------------------------------------- 1 | /* Copyright 2016-2017 Vector Creations Ltd 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package fclient 17 | 18 | import ( 19 | "context" 20 | "fmt" 21 | "net" 22 | "strconv" 23 | 24 | "github.com/matrix-org/gomatrixserverlib/spec" 25 | ) 26 | 27 | // ResolutionResult is a result of looking up a Matrix homeserver according to 28 | // the federation specification. 29 | type ResolutionResult struct { 30 | Destination string // The hostname and port to send federation requests to. 31 | Host spec.ServerName // The value of the Host headers. 32 | TLSServerName string // The TLS server name to request a certificate for. 33 | } 34 | 35 | // ResolveServer implements the server name resolution algorithm described at 36 | // https://matrix.org/docs/spec/server_server/r0.1.1.html#resolving-server-names 37 | // Returns a slice of ResolutionResult that can be used to send a federation 38 | // request to the server using a given server name. 39 | // Returns an error if the server name isn't valid. 40 | func ResolveServer(ctx context.Context, serverName spec.ServerName) (results []ResolutionResult, err error) { 41 | return resolveServer(ctx, serverName, true) 42 | } 43 | 44 | // resolveServer does the same thing as ResolveServer, except it also requires 45 | // the checkWellKnown parameter, which indicates whether a .well-known file 46 | // should be looked up. 47 | func resolveServer(ctx context.Context, serverName spec.ServerName, checkWellKnown bool) (results []ResolutionResult, err error) { 48 | host, port, valid := spec.ParseAndValidateServerName(serverName) 49 | if !valid { 50 | err = fmt.Errorf("Invalid server name") 51 | return 52 | } 53 | 54 | // 1. If the hostname is an IP literal 55 | // Check if we're dealing with an IPv6 literal with square brackets. If so, 56 | // remove the brackets. 
57 | if host[0] == '[' && host[len(host)-1] == ']' { 58 | host = host[1 : len(host)-1] 59 | } 60 | if net.ParseIP(host) != nil { 61 | var destination string 62 | 63 | if port == -1 { 64 | destination = net.JoinHostPort(host, strconv.Itoa(8448)) 65 | } else { 66 | destination = string(serverName) 67 | } 68 | 69 | results = []ResolutionResult{ 70 | { 71 | Destination: destination, 72 | Host: serverName, 73 | TLSServerName: host, 74 | }, 75 | } 76 | 77 | return 78 | } 79 | 80 | // 2. If the hostname is not an IP literal, and the server name includes an 81 | // explicit port 82 | if port != -1 { 83 | results = []ResolutionResult{ 84 | { 85 | Destination: string(serverName), 86 | Host: serverName, 87 | TLSServerName: host, 88 | }, 89 | } 90 | 91 | return 92 | } 93 | 94 | if checkWellKnown { 95 | // 3. If the hostname is not an IP literal 96 | var result *WellKnownResult 97 | result, err = LookupWellKnown(ctx, serverName) 98 | if err == nil { 99 | // We don't want to check .well-known on the result 100 | return resolveServer(ctx, result.NewAddress, false) 101 | } 102 | } 103 | 104 | return handleNoWellKnown(ctx, serverName), nil 105 | } 106 | 107 | // handleNoWellKnown implements steps 4 and 5 of the resolution algorithm (as 108 | // well as 3.3 and 3.4) 109 | func handleNoWellKnown(ctx context.Context, serverName spec.ServerName) (results []ResolutionResult) { 110 | // 4. If the /.well-known request resulted in an error response 111 | records, err := lookupSRV(ctx, serverName) 112 | if err == nil && len(records) > 0 { 113 | for _, rec := range records { 114 | // If the domain is a FQDN, remove the trailing dot at the end. This 115 | // isn't critical to send the request, as Go's HTTP client and most 116 | // servers understand FQDNs quite well, but it makes automated 117 | // testing easier. 118 | target := rec.Target 119 | if target[len(target)-1] == '.' 
{ 120 | target = target[:len(target)-1] 121 | } 122 | 123 | results = append(results, ResolutionResult{ 124 | Destination: fmt.Sprintf("%s:%d", target, rec.Port), 125 | Host: serverName, 126 | TLSServerName: string(serverName), 127 | }) 128 | } 129 | 130 | return 131 | } 132 | 133 | // 5. If the /.well-known request returned an error response, and the SRV 134 | // record was not found 135 | results = []ResolutionResult{ 136 | { 137 | Destination: fmt.Sprintf("%s:%d", serverName, 8448), 138 | Host: serverName, 139 | TLSServerName: string(serverName), 140 | }, 141 | } 142 | 143 | return 144 | } 145 | 146 | func lookupSRV(ctx context.Context, serverName spec.ServerName) ([]*net.SRV, error) { 147 | // Check matrix-fed service first, as of Matrix 1.8 148 | _, records, err := net.DefaultResolver.LookupSRV(ctx, "matrix-fed", "tcp", string(serverName)) 149 | if err != nil { 150 | if dnserr, ok := err.(*net.DNSError); ok { 151 | if !dnserr.IsNotFound { 152 | // not found errors are expected, but everything else is very much not 153 | return records, err 154 | } 155 | } else { 156 | return records, err 157 | } 158 | } else { 159 | return records, nil // we got a hit on the matrix-fed service, so use that 160 | } 161 | 162 | // we didn't get a hit on matrix-fed, so try deprecated matrix service 163 | _, records, err = net.DefaultResolver.LookupSRV(ctx, "matrix", "tcp", string(serverName)) 164 | return records, err // we don't need to process this here 165 | } 166 | -------------------------------------------------------------------------------- /keys.go: -------------------------------------------------------------------------------- 1 | /* Copyright 2016-2017 Vector Creations Ltd 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package gomatrixserverlib 17 | 18 | import ( 19 | "encoding/json" 20 | "strings" 21 | "time" 22 | 23 | "github.com/matrix-org/gomatrixserverlib/spec" 24 | ) 25 | 26 | // ServerKeys are the ed25519 signing keys published by a matrix server. 27 | // Contains SHA256 fingerprints of the TLS X509 certificates used by the server. 28 | type ServerKeys struct { 29 | // Copy of the raw JSON for signature checking. 30 | Raw []byte 31 | // The decoded JSON fields. 32 | ServerKeyFields 33 | } 34 | 35 | // A VerifyKey is a ed25519 public key for a server. 36 | type VerifyKey struct { 37 | // The public key. 38 | Key spec.Base64Bytes `json:"key"` 39 | } 40 | 41 | // An OldVerifyKey is an old ed25519 public key that is no longer valid. 42 | type OldVerifyKey struct { 43 | VerifyKey 44 | // When this key stopped being valid for event signing in milliseconds. 45 | ExpiredTS spec.Timestamp `json:"expired_ts"` 46 | } 47 | 48 | // ServerKeyFields are the parsed JSON contents of the ed25519 signing keys published by a matrix server. 49 | type ServerKeyFields struct { 50 | // The name of the server 51 | ServerName spec.ServerName `json:"server_name"` 52 | // The current signing keys in use on this server. 53 | // The keys of the map are the IDs of the keys. 54 | // These are valid while this response is valid. 55 | VerifyKeys map[KeyID]VerifyKey `json:"verify_keys"` 56 | // When this result is valid until in milliseconds. 
57 | ValidUntilTS spec.Timestamp `json:"valid_until_ts"` 58 | // Old keys that are now only valid for checking historic events. 59 | // The keys of the map are the IDs of the keys. 60 | OldVerifyKeys map[KeyID]OldVerifyKey `json:"old_verify_keys"` 61 | } 62 | 63 | // UnmarshalJSON implements json.Unmarshaler 64 | func (keys *ServerKeys) UnmarshalJSON(data []byte) error { 65 | keys.Raw = data 66 | return json.Unmarshal(data, &keys.ServerKeyFields) 67 | } 68 | 69 | // MarshalJSON implements json.Marshaler 70 | func (keys ServerKeys) MarshalJSON() ([]byte, error) { 71 | if len(keys.Raw) == 0 { 72 | js, err := json.Marshal(keys.ServerKeyFields) 73 | if err != nil { 74 | return nil, err 75 | } 76 | return js, nil 77 | } 78 | // We already have a copy of the serialised JSON for the keys so we can return that directly. 79 | return keys.Raw, nil 80 | } 81 | 82 | // PublicKey returns a public key with the given ID valid at the given TS or nil if no such key exists. 83 | func (keys ServerKeys) PublicKey(keyID KeyID, atTS spec.Timestamp) []byte { 84 | if currentKey, ok := keys.VerifyKeys[keyID]; ok && (atTS <= keys.ValidUntilTS) { 85 | return currentKey.Key 86 | } 87 | if oldKey, ok := keys.OldVerifyKeys[keyID]; ok && (atTS <= oldKey.ExpiredTS) { 88 | return oldKey.Key 89 | } 90 | return nil 91 | } 92 | 93 | // Ed25519Checks are the checks that are applied to Ed25519 keys in ServerKey responses. 94 | type Ed25519Checks struct { 95 | ValidEd25519 bool // The verify key is valid Ed25519 keys. 96 | MatchingSignature bool // The verify key has a valid signature. 97 | } 98 | 99 | // KeyChecks are the checks that should be applied to ServerKey responses. 100 | type KeyChecks struct { 101 | AllChecksOK bool // Did all the checks pass? 102 | MatchingServerName bool // Does the server name match what was requested. 103 | FutureValidUntilTS bool // The valid until TS is in the future. 104 | HasEd25519Key bool // The server has at least one ed25519 key. 
105 | AllEd25519ChecksOK *bool // All the Ed25519 checks are ok. or null if there weren't any to check. 106 | Ed25519Checks map[KeyID]Ed25519Checks // Checks for Ed25519 keys. 107 | } 108 | 109 | // CheckKeys checks the keys returned from a server to make sure they are valid. 110 | // If the checks pass then also return a map of key_id to Ed25519 public key 111 | func CheckKeys( 112 | serverName spec.ServerName, 113 | now time.Time, 114 | keys ServerKeys, 115 | ) ( 116 | checks KeyChecks, 117 | ed25519Keys map[KeyID]spec.Base64Bytes, 118 | ) { 119 | checks.MatchingServerName = serverName == keys.ServerName 120 | checks.FutureValidUntilTS = keys.ValidUntilTS.Time().After(now) 121 | checks.AllChecksOK = checks.MatchingServerName && checks.FutureValidUntilTS 122 | 123 | ed25519Keys = checkVerifyKeys(keys, &checks) 124 | 125 | if !checks.AllChecksOK { 126 | ed25519Keys = nil 127 | } 128 | return 129 | } 130 | 131 | func checkVerifyKeys(keys ServerKeys, checks *KeyChecks) map[KeyID]spec.Base64Bytes { 132 | allEd25519ChecksOK := true 133 | checks.Ed25519Checks = map[KeyID]Ed25519Checks{} 134 | verifyKeys := map[KeyID]spec.Base64Bytes{} 135 | for keyID, keyData := range keys.VerifyKeys { 136 | algorithm := strings.SplitN(string(keyID), ":", 2)[0] 137 | publicKey := keyData.Key 138 | if algorithm == "ed25519" { 139 | checks.HasEd25519Key = true 140 | checks.AllEd25519ChecksOK = &allEd25519ChecksOK 141 | entry := Ed25519Checks{ 142 | ValidEd25519: len(publicKey) == 32, 143 | } 144 | if entry.ValidEd25519 { 145 | err := VerifyJSON(string(keys.ServerName), keyID, []byte(publicKey), keys.Raw) 146 | entry.MatchingSignature = err == nil 147 | } 148 | checks.Ed25519Checks[keyID] = entry 149 | if entry.MatchingSignature { 150 | verifyKeys[keyID] = publicKey 151 | } else { 152 | allEd25519ChecksOK = false 153 | } 154 | } 155 | } 156 | if checks.AllChecksOK { 157 | checks.AllChecksOK = checks.HasEd25519Key && allEd25519ChecksOK 158 | } 159 | return verifyKeys 160 | } 161 | 
-------------------------------------------------------------------------------- /signing_test.go: -------------------------------------------------------------------------------- 1 | /* Copyright 2016-2017 Vector Creations Ltd 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package gomatrixserverlib 17 | 18 | import ( 19 | "bytes" 20 | "encoding/base64" 21 | "encoding/json" 22 | "testing" 23 | 24 | "golang.org/x/crypto/ed25519" 25 | ) 26 | 27 | func TestVerifyJSON(t *testing.T) { 28 | // Check JSON verification using the test vectors from https://matrix.org/docs/spec/appendices.html 29 | seed, err := base64.RawStdEncoding.DecodeString("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1") 30 | if err != nil { 31 | t.Fatal(err) 32 | } 33 | random := bytes.NewBuffer(seed) 34 | entityName := "domain" 35 | keyID := KeyID("ed25519:1") 36 | 37 | publicKey, _, err := ed25519.GenerateKey(random) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | 42 | testVerifyOK := func(input string) { 43 | err := VerifyJSON(entityName, keyID, publicKey, []byte(input)) 44 | if err != nil { 45 | t.Fatal(err) 46 | } 47 | } 48 | 49 | testVerifyNotOK := func(reason, input string) { 50 | err := VerifyJSON(entityName, keyID, publicKey, []byte(input)) 51 | if err == nil { 52 | t.Fatalf("Expected VerifyJSON to fail for input %v because %v", input, reason) 53 | } 54 | } 55 | 56 | testVerifyOK(`{ 57 | "signatures": { 58 | "domain": { 59 | "ed25519:1": 
"K8280/U9SSy9IVtjBuVeLr+HpOB4BQFWbg+UZaADMtTdGYI7Geitb76LTrr5QV/7Xg4ahLwYGYZzuHGZKM5ZAQ" 60 | } 61 | } 62 | }`) 63 | 64 | testVerifyNotOK("the json is modified", `{ 65 | "a new key": "a new value", 66 | "signatures": { 67 | "domain": { 68 | "ed25519:1": "K8280/U9SSy9IVtjBuVeLr+HpOB4BQFWbg+UZaADMtTdGYI7Geitb76LTrr5QV/7Xg4ahLwYGYZzuHGZKM5ZAQ" 69 | } 70 | } 71 | }`) 72 | testVerifyNotOK("the signature is modified", `{ 73 | "a new key": "a new value", 74 | "signatures": { 75 | "domain": { 76 | "ed25519:1": "modifiedSSy9IVtjBuVeLr+HpOB4BQFWbg+UZaADMtTdGYI7Geitb76LTrr5QV/7Xg4ahLwYGYZzuHGZKM5ZAQ" 77 | } 78 | } 79 | }`) 80 | testVerifyNotOK("there are no signatures", `{}`) 81 | testVerifyNotOK("there are no signatures", `{"signatures": {}}`) 82 | 83 | testVerifyNotOK("there are not signatures for domain", `{ 84 | "signatures": {"domain": {}} 85 | }`) 86 | testVerifyNotOK("the signature has the wrong key_id", `{ 87 | "signatures": { "domain": { 88 | "ed25519:2":"KqmLSbO39/Bzb0QIYE82zqLwsA+PDzYIpIRA2sRQ4sL53+sN6/fpNSoqE7BP7vBZhG6kYdD13EIMJpvhJI+6Bw" 89 | }} 90 | }`) 91 | testVerifyNotOK("the signature is too short for ed25519", `{"signatures": {"domain": {"ed25519:1":"not/a/valid/signature"}}}`) 92 | testVerifyNotOK("the signature has base64 padding that it shouldn't have", `{ 93 | "signatures": { "domain": { 94 | "ed25519:1": "K8280/U9SSy9IVtjBuVeLr+HpOB4BQFWbg+UZaADMtTdGYI7Geitb76LTrr5QV/7Xg4ahLwYGYZzuHGZKM5ZAQ==" 95 | }} 96 | }`) 97 | } 98 | 99 | func TestSignJSON(t *testing.T) { 100 | random := bytes.NewBuffer([]byte("Some 32 randomly generated bytes")) 101 | entityName := "example.com" 102 | keyID := KeyID("ed25519:my_key_id") 103 | input := []byte(`{"this":"is","my":"message"}`) 104 | 105 | publicKey, privateKey, err := ed25519.GenerateKey(random) 106 | if err != nil { 107 | t.Fatal(err) 108 | } 109 | 110 | signed, err := SignJSON(entityName, keyID, privateKey, input) 111 | if err != nil { 112 | t.Fatal(err) 113 | } 114 | 115 | err = VerifyJSON(entityName, keyID, 
publicKey, signed) 116 | if err != nil { 117 | t.Errorf("VerifyJSON(%q)", signed) 118 | t.Fatal(err) 119 | } 120 | } 121 | 122 | func IsJSONEqual(a, b []byte) bool { 123 | canonicalA, err := CanonicalJSON(a) 124 | if err != nil { 125 | panic(err) 126 | } 127 | canonicalB, err := CanonicalJSON(b) 128 | if err != nil { 129 | panic(err) 130 | } 131 | return string(canonicalA) == string(canonicalB) 132 | } 133 | 134 | func TestSignJSONTestVectors(t *testing.T) { 135 | // Check JSON signing using the test vectors from https://matrix.org/docs/spec/appendices.html 136 | seed, err := base64.RawStdEncoding.DecodeString("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1") 137 | if err != nil { 138 | t.Fatal(err) 139 | } 140 | random := bytes.NewBuffer(seed) 141 | entityName := "domain" 142 | keyID := KeyID("ed25519:1") 143 | 144 | _, privateKey, err := ed25519.GenerateKey(random) 145 | if err != nil { 146 | t.Fatal(err) 147 | } 148 | 149 | testSign := func(input string, want string) { 150 | signed, err := SignJSON(entityName, keyID, privateKey, []byte(input)) 151 | if err != nil { 152 | t.Fatal(err) 153 | } 154 | 155 | if !IsJSONEqual([]byte(want), signed) { 156 | t.Fatalf("VerifyJSON(%q): want %v got %v", input, want, string(signed)) 157 | } 158 | } 159 | 160 | testSign(`{}`, `{ 161 | "signatures":{ 162 | "domain":{ 163 | "ed25519:1":"K8280/U9SSy9IVtjBuVeLr+HpOB4BQFWbg+UZaADMtTdGYI7Geitb76LTrr5QV/7Xg4ahLwYGYZzuHGZKM5ZAQ" 164 | } 165 | } 166 | }`) 167 | 168 | testSign(`{"one":1,"two":"Two"}`, `{ 169 | "one": 1, 170 | "signatures": { 171 | "domain": { 172 | "ed25519:1": "KqmLSbO39/Bzb0QIYE82zqLwsA+PDzYIpIRA2sRQ4sL53+sN6/fpNSoqE7BP7vBZhG6kYdD13EIMJpvhJI+6Bw" 173 | } 174 | }, 175 | "two": "Two" 176 | }`) 177 | } 178 | 179 | type MyMessage struct { 180 | Unsigned *json.RawMessage `json:"unsigned"` 181 | Content *json.RawMessage `json:"content"` 182 | Signatures *json.RawMessage `json:"signatures,omitempty"` 183 | } 184 | 185 | func TestSignJSONWithUnsigned(t *testing.T) { 186 | random 
:= bytes.NewBuffer([]byte("Some 32 randomly generated bytes")) 187 | entityName := "example.com" 188 | keyID := KeyID("ed25519:my_key_id") 189 | content := json.RawMessage(`{"signed":"data"}`) 190 | unsigned := json.RawMessage(`{"unsigned":"data"}`) 191 | message := MyMessage{&unsigned, &content, nil} 192 | 193 | input, err := json.Marshal(&message) 194 | if err != nil { 195 | t.Fatal(err) 196 | } 197 | 198 | publicKey, privateKey, err := ed25519.GenerateKey(random) 199 | if err != nil { 200 | t.Fatal(err) 201 | } 202 | 203 | signed, err := SignJSON(entityName, keyID, privateKey, input) 204 | if err != nil { 205 | t.Fatal(err) 206 | } 207 | 208 | if err2 := json.Unmarshal(signed, &message); err2 != nil { 209 | t.Fatal(err2) 210 | } 211 | newUnsigned := json.RawMessage(`{"different":"data"}`) 212 | message.Unsigned = &newUnsigned 213 | input, err = json.Marshal(&message) 214 | if err != nil { 215 | t.Fatal(err) 216 | } 217 | 218 | err = VerifyJSON(entityName, keyID, publicKey, input) 219 | if err != nil { 220 | t.Errorf("VerifyJSON(%q)", signed) 221 | t.Fatal(err) 222 | } 223 | } 224 | -------------------------------------------------------------------------------- /redactevent_test.go: -------------------------------------------------------------------------------- 1 | package gomatrixserverlib 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestRedactionAlgorithmV4(t *testing.T) { 9 | // Specifically, the version 4 redaction algorithm used in room 10 | // version 9 is ensuring that the `join_authorised_via_users_server` 11 | // key doesn't get redacted. 
12 | 13 | input := []byte(`{"content":{"avatar_url":"mxc://something/somewhere","displayname":"Someone","join_authorised_via_users_server":"@someoneelse:somewhere.org","membership":"join"},"origin_server_ts":1633108629915,"sender":"@someone:somewhere.org","state_key":"@someone:somewhere.org","type":"m.room.member","unsigned":{"age":539338},"room_id":"!someroom:matrix.org"}`) 14 | expectedv8 := CanonicalJSONAssumeValid([]byte(`{"sender":"@someone:somewhere.org","room_id":"!someroom:matrix.org","content":{"membership":"join"},"type":"m.room.member","state_key":"@someone:somewhere.org","origin_server_ts":1633108629915}`)) 15 | expectedv9 := CanonicalJSONAssumeValid([]byte(`{"sender":"@someone:somewhere.org","room_id":"!someroom:matrix.org","content":{"membership":"join","join_authorised_via_users_server":"@someoneelse:somewhere.org"},"type":"m.room.member","state_key":"@someone:somewhere.org","origin_server_ts":1633108629915}`)) 16 | 17 | redactedv8, err := MustGetRoomVersion(RoomVersionV8).RedactEventJSON(input) 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | 22 | redactedv9, err := MustGetRoomVersion(RoomVersionV9).RedactEventJSON(input) 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | redactedv8 = CanonicalJSONAssumeValid(redactedv8) 27 | redactedv9 = CanonicalJSONAssumeValid(redactedv9) 28 | 29 | if !bytes.Equal(redactedv8, expectedv8) { 30 | t.Fatalf("room version 8 redaction produced unexpected result\nexpected: %s\ngot: %s", string(expectedv8), string(redactedv8)) 31 | } 32 | 33 | if !bytes.Equal(redactedv9, expectedv9) { 34 | t.Fatalf("room version 9 redaction produced unexpected result\nexpected: %s\ngot: %s", string(expectedv9), string(redactedv9)) 35 | } 36 | 37 | redactedv8withv9, err := MustGetRoomVersion(RoomVersionV9).RedactEventJSON(expectedv8) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | redactedv8withv9 = CanonicalJSONAssumeValid(redactedv8withv9) 42 | if !bytes.Equal(redactedv8withv9, expectedv8) { 43 | t.Fatalf("room version 8 redaction 
produced unexpected result\nexpected: %s\ngot: %s", string(expectedv8), string(redactedv8withv9)) 44 | } 45 | } 46 | 47 | func TestRedactionAlgorithmV5(t *testing.T) { 48 | // Specifically, the version 5 redaction algorithm used in room 49 | // version 11 is ensuring that: 50 | // - `m.room.create` keeps all `content` fields 51 | // - `m.room.redaction` keeps `redacts` `content` field 52 | // - `m.room.power_levels` keeps `invite` `content` field 53 | // - top level `origin`, `membership`, and `prev_state` aren't protected from redaction 54 | 55 | input := []byte(`{"content":{"placeholder":"value"},"origin_server_ts":1633108629915,"sender":"@someone:somewhere.org","state_key":"@someone:somewhere.org","type":"m.room.create","unsigned":{"age":539338},"room_id":"!someroom:matrix.org","origin":"matrix.org","membership":"join","prev_state":""}`) 56 | expectedv10 := CanonicalJSONAssumeValid([]byte(`{"sender":"@someone:somewhere.org","room_id":"!someroom:matrix.org","content":{},"type":"m.room.create","state_key":"@someone:somewhere.org","prev_state":"","origin":"matrix.org","origin_server_ts":1633108629915,"membership":"join"}`)) 57 | expectedv11 := CanonicalJSONAssumeValid([]byte(`{"sender":"@someone:somewhere.org","room_id":"!someroom:matrix.org","content":{"placeholder":"value"},"type":"m.room.create","state_key":"@someone:somewhere.org","origin_server_ts":1633108629915}`)) 58 | expectedv10withv11 := CanonicalJSONAssumeValid([]byte(`{"sender":"@someone:somewhere.org","room_id":"!someroom:matrix.org","content":{},"type":"m.room.create","state_key":"@someone:somewhere.org","origin_server_ts":1633108629915}`)) 59 | 60 | redactedv10, err := MustGetRoomVersion(RoomVersionV10).RedactEventJSON(input) 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | redactedv11, err := MustGetRoomVersion(RoomVersionV11).RedactEventJSON(input) 66 | if err != nil { 67 | t.Fatal(err) 68 | } 69 | redactedv10 = CanonicalJSONAssumeValid(redactedv10) 70 | redactedv11 = 
CanonicalJSONAssumeValid(redactedv11) 71 | 72 | if !bytes.Equal(redactedv10, expectedv10) { 73 | t.Fatalf("room version 10 redaction produced unexpected result\nexpected: %s\ngot: %s", string(expectedv10), string(redactedv10)) 74 | } 75 | 76 | if !bytes.Equal(redactedv11, expectedv11) { 77 | t.Fatalf("room version 11 redaction produced unexpected result\nexpected: %s\ngot: %s", string(expectedv11), string(redactedv11)) 78 | } 79 | 80 | redactedv10withv11, err := MustGetRoomVersion(RoomVersionV11).RedactEventJSON(expectedv10) 81 | if err != nil { 82 | t.Fatal(err) 83 | } 84 | redactedv10withv11 = CanonicalJSONAssumeValid(redactedv10withv11) 85 | if !bytes.Equal(redactedv10withv11, expectedv10withv11) { 86 | t.Fatalf("room version 11 redaction produced unexpected result\nexpected: %s\ngot: %s", string(expectedv10withv11), string(redactedv10withv11)) 87 | } 88 | 89 | powerLevelsInput := []byte(`{"content":{"invite":"","placeholder":"value"},"origin_server_ts":1633108629915,"sender":"@someone:somewhere.org","state_key":"@someone:somewhere.org","type":"m.room.power_levels","unsigned":{"age":539338},"room_id":"!someroom:matrix.org","origin":"matrix.org","membership":"join","prev_state":""}`) 90 | expectedv11PLs := CanonicalJSONAssumeValid([]byte(`{"sender":"@someone:somewhere.org","room_id":"!someroom:matrix.org","content":{"invite":""},"type":"m.room.power_levels","state_key":"@someone:somewhere.org","origin_server_ts":1633108629915}`)) 91 | 92 | redactedv11PLs, err := MustGetRoomVersion(RoomVersionV11).RedactEventJSON(powerLevelsInput) 93 | if err != nil { 94 | t.Fatal(err) 95 | } 96 | redactedv11PLs = CanonicalJSONAssumeValid(redactedv11PLs) 97 | if !bytes.Equal(redactedv11PLs, expectedv11PLs) { 98 | t.Fatalf("room version 11 redaction produced unexpected result\nexpected: %s\ngot: %s", string(expectedv11PLs), string(redactedv11PLs)) 99 | } 100 | 101 | readactionInput := 
[]byte(`{"content":{"redacts":"","placeholder":"value"},"origin_server_ts":1633108629915,"sender":"@someone:somewhere.org","state_key":"@someone:somewhere.org","type":"m.room.redaction","unsigned":{"age":539338},"room_id":"!someroom:matrix.org","origin":"matrix.org","membership":"join","prev_state":""}`) 102 | expectedv11Redaction := CanonicalJSONAssumeValid([]byte(`{"sender":"@someone:somewhere.org","room_id":"!someroom:matrix.org","content":{"redacts":""},"type":"m.room.redaction","state_key":"@someone:somewhere.org","origin_server_ts":1633108629915}`)) 103 | 104 | redactedv11Redaction, err := MustGetRoomVersion(RoomVersionV11).RedactEventJSON(readactionInput) 105 | if err != nil { 106 | t.Fatal(err) 107 | } 108 | redactedv11Redaction = CanonicalJSONAssumeValid(redactedv11Redaction) 109 | if !bytes.Equal(redactedv11Redaction, expectedv11Redaction) { 110 | t.Fatalf("room version 11 redaction produced unexpected result\nexpected: %s\ngot: %s", string(expectedv11Redaction), string(redactedv11Redaction)) 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= 6 | github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= 7 | github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= 8 | github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= 9 | github.com/google/go-cmp v0.5.9/go.mod 
h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 10 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 11 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 12 | github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= 13 | github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= 14 | github.com/hashicorp/go-set/v3 v3.0.0 h1:CaJBQvQCOWoftrBcDt7Nwgo0kdpmrKxar/x2o6pV9JA= 15 | github.com/hashicorp/go-set/v3 v3.0.0/go.mod h1:IEghM2MpE5IaNvL+D7X480dfNtxjRXZ6VMpK3C8s2ok= 16 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 17 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 18 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 19 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 20 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 21 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 22 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 23 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 24 | github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= 25 | github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= 26 | github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= 27 | github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4= 28 | github.com/miekg/dns v1.1.66 h1:FeZXOS3VCVsKnEAd+wBkjMC3D2K+ww66Cq3VnCINuJE= 29 | github.com/miekg/dns v1.1.66/go.mod h1:jGFzBsSNbJw6z1HYut1RKBKHA9PBdxeHrZG8J+gC2WE= 30 | 
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= 31 | github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= 32 | github.com/oleiade/lane/v2 v2.0.0 h1:XW/ex/Inr+bPkLd3O240xrFOhUkTd4Wy176+Gv0E3Qw= 33 | github.com/oleiade/lane/v2 v2.0.0/go.mod h1:i5FBPFAYSWCgLh58UkUGCChjcCzef/MI7PlQm2TKCeg= 34 | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= 35 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 36 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 37 | github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= 38 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 39 | github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= 40 | github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= 41 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 42 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 43 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 44 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 45 | github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 46 | github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= 47 | github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 48 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= 49 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= 50 | github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 51 | 
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= 52 | github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 53 | github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= 54 | github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= 55 | golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= 56 | golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= 57 | golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= 58 | golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw= 59 | golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= 60 | golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= 61 | golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= 62 | golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= 63 | golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= 64 | golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= 65 | golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 66 | golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 67 | golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= 68 | golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 69 | golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= 70 | golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= 71 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 72 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c 
h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 73 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 74 | gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY= 75 | gopkg.in/h2non/gock.v1 v1.1.2/go.mod h1:n7UGz/ckNChHiK05rDoiC4MYSunEC/lyaUm2WWaDva0= 76 | gopkg.in/macaroon.v2 v2.1.0 h1:HZcsjBCzq9t0eBPMKqTN/uSN6JOm78ZJ2INbqcBQOUI= 77 | gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o= 78 | gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= 79 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 80 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 81 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 82 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 83 | --------------------------------------------------------------------------------