├── .gitignore ├── examples └── date │ ├── .gitignore │ ├── publisher.go │ └── README.md ├── internal ├── wire │ ├── stream_type.go │ ├── object_status.go │ ├── fetch_header_message.go │ ├── error.go │ ├── location.go │ ├── announce_ok_message.go │ ├── unannounce_message.go │ ├── unsubscribe_message.go │ ├── fetch_cancel_message.go │ ├── varint_bytes.go │ ├── max_request_id_message.go │ ├── subscribe_announces_ok_message.go │ ├── requests_blocked_message.go │ ├── unsubscribe_announces_message.go │ ├── go_away_message.go │ ├── stream_header_subgroup_message.go │ ├── announce_error_message.go │ ├── fetch_error_message.go │ ├── server_setup_message.go │ ├── subscribe_announces_error_message.go │ ├── publish_error_message.go │ ├── subscribe_error_message.go │ ├── announce_cancel_message.go │ ├── announce_message.go │ ├── track_status_message.go │ ├── subscribe_announces_message.go │ ├── tuple.go │ ├── client_setup_message.go │ ├── unsubscribe_message_test.go │ ├── go_away_message_test.go │ ├── control_message_parser_test.go │ ├── subscribe_done_message.go │ ├── track_status_request_message.go │ ├── announce_ok_message_test.go │ ├── fetch_ok_message.go │ ├── kvp_list_test.go │ ├── tuple_test.go │ ├── announce_message_test.go │ ├── subscribe_error_message_test.go │ ├── track_status_message_test.go │ ├── token.go │ ├── version.go │ ├── unannounce_message_test.go │ ├── version_test.go │ ├── subscribe_update_message.go │ ├── announce_error_message_test.go │ ├── publish_ok_message.go │ ├── publish_message.go │ ├── stream_header_subgroup_message_test.go │ ├── track_status_request_message_test.go │ ├── key_value_pair.go │ ├── announce_cancel_message_test.go │ ├── varint_bytes_test.go │ ├── control_message_parser.go │ ├── kvp_list.go │ ├── subscribe_ok_message.go │ ├── key_value_pair_test.go │ ├── object_datagram_message.go │ ├── object_message.go │ ├── subscribe_update_message_test.go │ ├── server_setup_message_test.go │ ├── client_setup_message_test.go │ ├── 
control_message_type.go │ ├── subscribe_done_message_test.go │ └── fetch_message.go └── slices │ └── slices.go ├── goleak_test.go ├── announcement.go ├── announcement_subscription.go ├── integrationtests ├── handshake_test.go ├── subscribe_announces_test.go ├── announce_test.go ├── fetch_test.go └── utils.go ├── track_status_request.go ├── sequence.go ├── object.go ├── announcement_response_writer.go ├── constants.go ├── announcement_subscription_response_writer.go ├── fetch_response_writer.go ├── quicmoq ├── receive_stream.go ├── send_stream.go ├── stream.go └── connection.go ├── track_status_request_map.go ├── webtransportmoq ├── receive_stream.go ├── send_stream.go ├── stream.go └── connection.go ├── go.mod ├── .github └── workflows │ └── go.yml ├── logging.go ├── track_status_response_writer.go ├── request_id.go ├── CONTRIBUTING.md ├── LICENSE ├── control_stream.go ├── announcement_subscription_map.go ├── local_track_map.go ├── announcement_map.go ├── session_helpers.go ├── mockgen.go ├── README.md ├── fetch_stream.go ├── mock_handler_test.go ├── messages_test.go ├── subgroup.go ├── subscribe_response_writer.go ├── remote_track_map.go ├── go.sum ├── connection.go ├── mock_control_message_stream_test.go ├── handler.go └── local_track.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.pem 2 | .idea/ 3 | -------------------------------------------------------------------------------- /examples/date/.gitignore: -------------------------------------------------------------------------------- 1 | date 2 | -------------------------------------------------------------------------------- /internal/wire/stream_type.go: -------------------------------------------------------------------------------- 1 | package wire 2 | -------------------------------------------------------------------------------- /goleak_test.go: -------------------------------------------------------------------------------- 1 | package moqtransport 
2 | 3 | import ( 4 | "testing" 5 | 6 | "go.uber.org/goleak" 7 | ) 8 | 9 | func TestMain(m *testing.M) { 10 | goleak.VerifyTestMain(m) 11 | } 12 | -------------------------------------------------------------------------------- /announcement.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import "github.com/mengelbart/moqtransport/internal/wire" 4 | 5 | type announcement struct { 6 | requestID uint64 7 | namespace []string 8 | parameters wire.KVPList 9 | 10 | response chan error 11 | } 12 | -------------------------------------------------------------------------------- /announcement_subscription.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | type announcementSubscriptionResponse struct { 4 | err error 5 | } 6 | 7 | type announcementSubscription struct { 8 | requestID uint64 9 | namespace []string 10 | response chan announcementSubscriptionResponse 11 | } 12 | -------------------------------------------------------------------------------- /integrationtests/handshake_test.go: -------------------------------------------------------------------------------- 1 | package integrationtests 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestHandshake(t *testing.T) { 8 | sConn, cConn, cancel := connect(t) 9 | defer cancel() 10 | 11 | _, _, cancel = setup(t, sConn, cConn, nil) 12 | defer cancel() 13 | } 14 | -------------------------------------------------------------------------------- /internal/wire/object_status.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | type ObjectStatus int 4 | 5 | const ( 6 | ObjectStatusNormal ObjectStatus = 0x00 7 | ObjectStatusObjectDoesNotExist ObjectStatus = 0x01 8 | ObjectStatusEndOfGroup ObjectStatus = 0x03 9 | ObjectStatusEndOfTrack ObjectStatus = 0x04 10 | ) 11 | 
--------------------------------------------------------------------------------
/track_status_request.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | // TrackStatus is the status reported for a track. StatusCode is expected
4 | // to hold one of the TrackStatus* constants.
5 | type TrackStatus struct {
6 | 	Namespace    []string
7 | 	Trackname    string
8 | 	StatusCode   uint64
9 | 	LastGroupID  uint64
10 | 	LastObjectID uint64
11 | }
12 |
13 | // trackStatusRequest holds the state of one outstanding track status
14 | // request. NOTE(review): response presumably receives the answer — confirm.
15 | type trackStatusRequest struct {
16 | 	requestID uint64
17 | 	namespace []string
18 | 	trackname string
19 |
20 | 	response chan *TrackStatus
21 | }
22 |
--------------------------------------------------------------------------------
/sequence.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | import "sync/atomic"
4 |
5 | // sequence produces the values initial, initial+interval,
6 | // initial+2*interval, ... and next is safe for concurrent use.
7 | type sequence struct {
8 | 	last     atomic.Uint64
9 | 	interval uint64
10 | }
11 |
12 | // newSequence returns a sequence whose first call to next yields initial.
13 | func newSequence(initial, interval uint64) *sequence {
14 | 	s := &sequence{interval: interval}
15 | 	s.last.Store(initial)
16 | 	return s
17 | }
18 |
19 | // next returns the current value and atomically advances the sequence by
20 | // interval. Values wrap around on uint64 overflow.
21 | func (s *sequence) next() uint64 {
22 | 	return s.last.Add(s.interval) - s.interval
23 | }
24 |
--------------------------------------------------------------------------------
/object.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | // ObjectForwardingPreference selects how an Object is forwarded: on a
4 | // subgroup stream or as a datagram.
5 | type ObjectForwardingPreference int
6 |
7 | const (
8 | 	// ObjectForwardingPreferenceSubgroup is the zero value.
9 | 	ObjectForwardingPreferenceSubgroup ObjectForwardingPreference = iota
10 | 	// ObjectForwardingPreferenceDatagram gets the value 1 from iota. An
11 | 	// implicit repetition of the preceding "= 0x00" would make it equal to
12 | 	// ObjectForwardingPreferenceSubgroup.
13 | 	ObjectForwardingPreferenceDatagram
14 | )
15 |
16 | // An Object is a MoQ Object.
17 | type Object struct {
18 | 	GroupID              uint64
19 | 	ObjectID             uint64
20 | 	ForwardingPreference ObjectForwardingPreference
21 | 	SubGroupID           uint64
22 | 	Payload              []byte
23 | }
24 |
--------------------------------------------------------------------------------
/announcement_response_writer.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | // announcementResponseWriter answers the announcement identified by
4 | // requestID through session.
5 | type announcementResponseWriter struct {
6 | 	requestID uint64
7 | 	session   *Session
8 | 	handled   bool
9 | }
10 |
11 | // Accept marks the announcement as handled and accepts it.
12 | func (a *announcementResponseWriter) Accept() error {
13 | 	a.handled = true
14 | 	return a.session.acceptAnnouncement(a.requestID)
15 | }
16 |
17 | // Reject marks the announcement as handled and rejects it with the given
18 | // error code and reason.
19 | func (a *announcementResponseWriter) Reject(code uint64, reason string) error {
20 | 	a.handled = true
21 | 	return a.session.rejectAnnouncement(a.requestID, code, reason)
22 | }
23 |
--------------------------------------------------------------------------------
/internal/wire/fetch_header_message.go:
--------------------------------------------------------------------------------
1 | package wire
2 |
3 | import (
4 | 	"github.com/quic-go/quic-go/quicvarint"
5 | )
6 |
7 | // FetchHeaderMessage is the header of a fetch stream.
8 | type FetchHeaderMessage struct {
9 | 	RequestID uint64
10 | }
11 |
12 | // Append appends the fetch stream type followed by the request ID, both as
13 | // varints, to buf and returns the extended slice.
14 | func (m *FetchHeaderMessage) Append(buf []byte) []byte {
15 | 	buf = quicvarint.Append(buf, uint64(StreamTypeFetch))
16 | 	return quicvarint.Append(buf, m.RequestID)
17 | }
18 |
19 | // parse reads only the request ID; the stream type written by Append is
20 | // presumably consumed by the caller before parse is invoked.
21 | func (m *FetchHeaderMessage) parse(reader messageReader) (err error) {
22 | 	m.RequestID, err = quicvarint.Read(reader)
23 | 	return err
24 | }
25 |
--------------------------------------------------------------------------------
/internal/wire/error.go:
--------------------------------------------------------------------------------
1 | package wire
2 |
3 | import "errors"
4 |
5 | // Sentinel errors used by this package; the names indicate message
6 | // decoding failures.
7 | var (
8 | 	errInvalidMessageType       = errors.New("invalid message type")
9 | 	errInvalidFilterType        = errors.New("invalid filter type")
10 | 	errInvalidContentExistsByte = errors.New("invalid use of ContentExists byte")
11 | 	errInvalidGroupOrder        = errors.New("invalid GroupOrder")
12 | 	errInvalidForwardFlag       = errors.New("invalid Forward flag")
13 | 	errLengthMismatch           = errors.New("length mismatch")
14 | 	errInvalidFetchType         = errors.New("invalid fetch type")
15 | )
16 |
--------------------------------------------------------------------------------
/constants.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | // Subscribe status codes.
4 | const (
5 | 	SubscribeStatusUnsubscribed      = 0x00
6 | 	SubscribeStatusInternalError     = 0x01
7 | 	SubscribeStatusUnauthorized      = 0x02
8 | 	SubscribeStatusTrackEnded        = 0x03
9 | 	SubscribeStatusSubscriptionEnded = 0x04
10 | 	SubscribeStatusGoingAway         = 0x05
11 | 	SubscribeStatusExpired           = 0x06
12 | )
13 |
14 | // Track status codes, as stored in TrackStatus.StatusCode.
15 | const (
16 | 	TrackStatusInProgress   = 0x00
17 | 	TrackStatusDoesNotExist = 0x01
18 | 	TrackStatusNotYetBegun  = 0x02
19 | 	TrackStatusFinished     = 0x03
20 | 	TrackStatusUnavailable  = 0x04
21 | )
22 |
--------------------------------------------------------------------------------
/announcement_subscription_response_writer.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | // announcementSubscriptionResponseWriter answers the announcement
4 | // subscription request identified by requestID through session.
5 | type announcementSubscriptionResponseWriter struct {
6 | 	requestID uint64
7 | 	session   *Session
8 | 	handled   bool
9 | }
10 |
11 | // Accept marks the request as handled and accepts it.
12 | func (a *announcementSubscriptionResponseWriter) Accept() error {
13 | 	a.handled = true
14 | 	return a.session.acceptAnnouncementSubscription(a.requestID)
15 | }
16 |
17 | // Reject marks the request as handled and rejects it with the given error
18 | // code and reason.
19 | func (a *announcementSubscriptionResponseWriter) Reject(code uint64, reason string) error {
20 | 	a.handled = true
21 | 	return a.session.rejectAnnouncementSubscription(a.requestID, code, reason)
22 | }
23 |
--------------------------------------------------------------------------------
/internal/slices/slices.go:
--------------------------------------------------------------------------------
1 | package slices
2 |
3 | import (
4 | 	"iter"
5 | 	"slices"
6 | )
7 |
8 | // Collect returns the values of seq as a new slice (see slices.Collect).
9 | func Collect[E any](seq iter.Seq[E]) []E {
10 | 	return slices.Collect(seq)
11 | }
12 |
13 | // Backward iterates over s from the last index to the first
14 | // (see slices.Backward).
15 | func Backward[Slice ~[]E, E any](s Slice) iter.Seq2[int, E] {
16 | 	return slices.Backward(s)
17 | }
18 |
19 | // Contains reports whether v is present in s (see slices.Contains).
20 | func Contains[S ~[]E, E comparable](s S, v E) bool {
21 | 	return slices.Contains(s, v)
22 | }
23 |
24 | // Map returns a lazy iterator yielding f(e) for each element of ee in
25 | // order; f is not called for elements after the consumer stops.
26 | func Map[K any, V any](ee []K, f func(e K) V) iter.Seq[V] {
27 | 	return func(yield func(V) bool) {
28 | 		for _, v := range ee {
29 | 			if !yield(f(v)) {
30 | 				return
31 | 			}
32 | 		}
33 | 	}
34 | }
35 |
--------------------------------------------------------------------------------
/internal/wire/location.go:
--------------------------------------------------------------------------------
1 | package wire
2 |
3 | import "github.com/quic-go/quic-go/quicvarint"
4 |
5 | // Location identifies an object by its group and object IDs.
6 | type Location struct {
7 | 	Group  uint64
8 | 	Object uint64
9 | }
10 |
11 | // append encodes l as two varints, Group followed by Object.
12 | func (l Location) append(buf []byte) []byte {
13 | 	buf = quicvarint.Append(buf, l.Group)
14 | 	return quicvarint.Append(buf, l.Object)
15 | }
16 |
17 | // parse decodes the two varints written by append and returns the total
18 | // number of bytes consumed.
19 | func (l *Location) parse(_ Version, data []byte) (n int, err error) {
20 | 	l.Group, n, err = quicvarint.Parse(data)
21 | 	if err != nil {
22 | 		return n, err
23 | 	}
24 | 	data = data[n:]
25 | 	var m int
26 | 	l.Object, m, err = quicvarint.Parse(data)
27 | 	return n + m, err
28 | }
29 |
--------------------------------------------------------------------------------
/fetch_response_writer.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | // fetchResponseWriter answers the fetch request identified by id.
4 | type fetchResponseWriter struct {
5 | 	id         uint64
6 | 	session    *Session
7 | 	localTrack *localTrack
8 | 	handled    bool
9 | }
10 |
11 | // Accept implements ResponseWriter.
12 | func (f *fetchResponseWriter) Accept() error {
13 | 	f.handled = true
14 | 	return f.session.acceptFetch(f.id)
15 | }
16 |
17 | // Reject implements ResponseWriter.
18 | func (f *fetchResponseWriter) Reject(code uint64, reason string) error {
19 | 	f.handled = true
20 | 	return f.session.rejectFetch(f.id, code, reason)
21 | }
22 |
23 | // FetchStream returns the fetch stream of the local track being fetched.
24 | func (f *fetchResponseWriter) FetchStream() (*FetchStream, error) {
25 | 	return f.localTrack.getFetchStream()
26 | }
27 |
--------------------------------------------------------------------------------
/internal/wire/announce_ok_message.go:
--------------------------------------------------------------------------------
1 | package wire
2 |
3 | import (
4 | 	"log/slog"
5 |
6 | 	"github.com/quic-go/quic-go/quicvarint"
7 | )
8 |
9 | // AnnounceOkMessage is the ANNOUNCE_OK control message.
10 | type AnnounceOkMessage struct {
11 | 	RequestID uint64
12 | }
13 |
14 | // LogValue implements slog.LogValuer. The request ID is included, as in the
15 | // other request-ID-carrying messages of this package.
16 | func (m *AnnounceOkMessage) LogValue() slog.Value {
17 | 	return slog.GroupValue(
18 | 		slog.String("type", "announce_ok"),
19 | 		slog.Uint64("request_id", m.RequestID),
20 | 	)
21 | }
22 |
23 | // Type returns messageTypeAnnounceOk.
24 | func (m AnnounceOkMessage) Type() controlMessageType {
25 | 	return messageTypeAnnounceOk
26 | }
27 |
28 | // Append appends the request ID as a varint to buf.
29 | func (m *AnnounceOkMessage) Append(buf []byte) []byte {
30 | 	return quicvarint.Append(buf, m.RequestID)
31 | }
32 |
33 | // parse decodes the request ID from data.
34 | func (m *AnnounceOkMessage) parse(_ Version, data []byte) (err error) {
35 | 	m.RequestID, _, err = quicvarint.Parse(data)
36 | 	return err
37 | }
38 |
--------------------------------------------------------------------------------
/internal/wire/unannounce_message.go:
--------------------------------------------------------------------------------
1 | package wire
2 |
3 | import (
4 | 	"log/slog"
5 | )
6 |
7 | // UnannounceMessage is the UNANNOUNCE control message.
8 | type UnannounceMessage struct {
9 | 	TrackNamespace Tuple
10 | }
11 |
12 | // LogValue implements slog.LogValuer.
13 | func (m *UnannounceMessage) LogValue() slog.Value {
14 | 	return slog.GroupValue(
15 | 		slog.String("type", "unannounce"),
16 | 		slog.Any("track_namespace", m.TrackNamespace),
17 | 	)
18 | }
19 |
20 | // Type returns messageTypeUnannounce.
21 | func (m UnannounceMessage) Type() controlMessageType {
22 | 	return messageTypeUnannounce
23 | }
24 |
25 | // Append appends the encoded track namespace to buf.
26 | func (m *UnannounceMessage) Append(buf []byte) []byte {
27 | 	return m.TrackNamespace.append(buf)
28 | }
29 |
30 | // parse decodes the track namespace from data.
31 | func (m *UnannounceMessage) parse(_ Version, data []byte) (err error) {
32 | 	m.TrackNamespace, _, err = parseTuple(data)
33 | 	return err
34 | }
35 |
--------------------------------------------------------------------------------
/quicmoq/receive_stream.go:
--------------------------------------------------------------------------------
1 | package quicmoq
2 |
3 | import (
4 | 	"github.com/mengelbart/moqtransport"
5 | 	"github.com/quic-go/quic-go"
6 | )
7 |
8 | var _ moqtransport.ReceiveStream = (*ReceiveStream)(nil)
9 |
10 | // ReceiveStream adapts a *quic.ReceiveStream to moqtransport.ReceiveStream.
11 | type ReceiveStream struct {
12 | 	stream *quic.ReceiveStream
13 | }
14 |
15 | // Read implements moqtransport.ReceiveStream.
16 | func (r *ReceiveStream) Read(p []byte) (n int, err error) {
17 | 	return r.stream.Read(p)
18 | }
19 |
20 | // Stop implements moqtransport.ReceiveStream by cancelling reads with code.
21 | func (r *ReceiveStream) Stop(code uint32) {
22 | 	r.stream.CancelRead(quic.StreamErrorCode(code))
23 | }
24 |
25 | // StreamID implements moqtransport.ReceiveStream.
26 | func (r *ReceiveStream) StreamID() uint64 {
27 | 	return uint64(r.stream.StreamID())
28 | }
29 |
--------------------------------------------------------------------------------
/track_status_request_map.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | import (
4 | 	"sync"
5 | )
6 |
7 | // trackStatusRequestMap stores outstanding track status requests by
8 | // request ID. All methods take lock and are safe for concurrent use.
9 | type trackStatusRequestMap struct {
10 | 	lock     sync.Mutex
11 | 	requests map[uint64]*trackStatusRequest
12 | }
13 |
14 | func newTrackStatusRequestMap() *trackStatusRequestMap {
15 | 	return &trackStatusRequestMap{
16 | 		requests: map[uint64]*trackStatusRequest{},
17 | 	}
18 | }
19 |
20 | // add stores tsr under its requestID, replacing any previous entry.
21 | func (m *trackStatusRequestMap) add(tsr *trackStatusRequest) {
22 | 	m.lock.Lock()
23 | 	defer m.lock.Unlock()
24 | 	m.requests[tsr.requestID] = tsr
25 | }
26 |
27 | // delete removes and returns the request stored under requestID. The
28 | // boolean reports whether such a request existed.
29 | func (m *trackStatusRequestMap) delete(requestID uint64) (*trackStatusRequest, bool) {
30 | 	m.lock.Lock()
31 | 	defer m.lock.Unlock()
32 | 	tsr, ok := m.requests[requestID]
33 | 	// Remove the entry; a lookup alone would keep every completed request
34 | 	// in the map. (The method name does not shadow the builtin delete.)
35 | 	delete(m.requests, requestID)
36 | 	return tsr, ok
37 | }
38 |
-------------------------------------------------------------------------------- /webtransportmoq/receive_stream.go: -------------------------------------------------------------------------------- 1 | package webtransportmoq 2 | 3 | import ( 4 | "github.com/mengelbart/moqtransport" 5 | "github.com/quic-go/webtransport-go" 6 | ) 7 | 8 | var _ moqtransport.ReceiveStream = (*ReceiveStream)(nil) 9 | 10 | type ReceiveStream struct { 11 | stream *webtransport.ReceiveStream 12 | } 13 | 14 | // Read implements moqtransport.ReceiveStream. 15 | func (r *ReceiveStream) Read(p []byte) (n int, err error) { 16 | return r.stream.Read(p) 17 | } 18 | 19 | // Stop implements moqtransport.ReceiveStream. 20 | func (r *ReceiveStream) Stop(code uint32) { 21 | r.stream.CancelRead(webtransport.StreamErrorCode(code)) 22 | } 23 | 24 | // StreamID implements moqtransport.ReceiveStream 25 | func (r *ReceiveStream) StreamID() uint64 { 26 | return uint64(r.stream.StreamID()) 27 | } 28 | -------------------------------------------------------------------------------- /examples/date/publisher.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/mengelbart/moqtransport" 7 | ) 8 | 9 | type publisher struct { 10 | p moqtransport.Publisher 11 | sessionID uint64 12 | requestID uint64 13 | } 14 | 15 | func (p *publisher) SendDatagram(o moqtransport.Object) error { 16 | return p.p.SendDatagram(o) 17 | } 18 | 19 | func (p *publisher) OpenSubgroup(groupID, subgroupID uint64, priority uint8) (*moqtransport.Subgroup, error) { 20 | log.Printf("sessionNr: %d, requestID: %d, groupID: %d, subgroupID: %v", 21 | p.sessionID, p.requestID, groupID, subgroupID) 22 | return p.p.OpenSubgroup(groupID, subgroupID, priority) 23 | } 24 | 25 | func (p *publisher) CloseWithError(code uint64, reason string) error { 26 | return p.p.CloseWithError(code, reason) 27 | } 28 | 
-------------------------------------------------------------------------------- /internal/wire/unsubscribe_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type UnsubscribeMessage struct { 10 | RequestID uint64 11 | } 12 | 13 | func (m *UnsubscribeMessage) LogValue() slog.Value { 14 | return slog.GroupValue( 15 | slog.String("type", "unsubscribe"), 16 | slog.Uint64("request_id", m.RequestID), 17 | ) 18 | } 19 | 20 | func (m UnsubscribeMessage) Type() controlMessageType { 21 | return messageTypeUnsubscribe 22 | } 23 | 24 | func (m *UnsubscribeMessage) Append(buf []byte) []byte { 25 | buf = quicvarint.Append(buf, m.RequestID) 26 | return buf 27 | } 28 | 29 | func (m *UnsubscribeMessage) parse(_ Version, data []byte) (err error) { 30 | m.RequestID, _, err = quicvarint.Parse(data) 31 | return err 32 | } 33 | -------------------------------------------------------------------------------- /internal/wire/fetch_cancel_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | // TODO: Add tests 10 | type FetchCancelMessage struct { 11 | RequestID uint64 12 | } 13 | 14 | func (m *FetchCancelMessage) LogValue() slog.Value { 15 | return slog.GroupValue( 16 | slog.String("type", "fetch_cancel"), 17 | slog.Uint64("request_id", m.RequestID), 18 | ) 19 | } 20 | 21 | func (m FetchCancelMessage) Type() controlMessageType { 22 | return messageTypeFetchCancel 23 | } 24 | 25 | func (m *FetchCancelMessage) Append(buf []byte) []byte { 26 | return quicvarint.Append(buf, m.RequestID) 27 | } 28 | 29 | func (m *FetchCancelMessage) parse(_ Version, data []byte) (err error) { 30 | m.RequestID, _, err = quicvarint.Parse(data) 31 | return err 32 | } 33 | 
-------------------------------------------------------------------------------- /internal/wire/varint_bytes.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | func appendVarIntBytes(buf []byte, data []byte) []byte { 10 | buf = quicvarint.Append(buf, uint64(len(data))) 11 | buf = append(buf, data...) 12 | return buf 13 | } 14 | 15 | func varIntBytesLen(s string) uint64 { 16 | return uint64(quicvarint.Len(uint64(len(s)))) + uint64(len(s)) 17 | } 18 | 19 | func parseVarIntBytes(data []byte) ([]byte, int, error) { 20 | l, n, err := quicvarint.Parse(data) 21 | if err != nil { 22 | return []byte{}, n, err 23 | } 24 | 25 | if l == 0 { 26 | return []byte{}, n, nil 27 | } 28 | data = data[n:] 29 | 30 | if len(data) < int(l) { 31 | return []byte{}, n + len(data), io.ErrUnexpectedEOF 32 | } 33 | return data[:l], n + int(l), nil 34 | } 35 | -------------------------------------------------------------------------------- /internal/wire/max_request_id_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | // TODO: Add tests 10 | type MaxRequestIDMessage struct { 11 | RequestID uint64 12 | } 13 | 14 | func (m *MaxRequestIDMessage) LogValue() slog.Value { 15 | return slog.GroupValue( 16 | slog.String("type", "max_request_id"), 17 | slog.Uint64("max_request_id", m.RequestID), 18 | ) 19 | } 20 | 21 | func (m MaxRequestIDMessage) Type() controlMessageType { 22 | return messageTypeMaxRequestID 23 | } 24 | 25 | func (m *MaxRequestIDMessage) Append(buf []byte) []byte { 26 | return quicvarint.Append(buf, m.RequestID) 27 | } 28 | 29 | func (m *MaxRequestIDMessage) parse(_ Version, data []byte) (err error) { 30 | m.RequestID, _, err = quicvarint.Parse(data) 31 | return err 32 | } 33 | 
--------------------------------------------------------------------------------
/integrationtests/subscribe_announces_test.go:
--------------------------------------------------------------------------------
1 | package integrationtests
2 |
3 | import (
4 | 	"context"
5 | 	"testing"
6 |
7 | 	"github.com/mengelbart/moqtransport"
8 | 	"github.com/stretchr/testify/assert"
9 | )
10 |
11 | // TestSubscribeAnnounces checks that a SubscribeAnnouncements request
12 | // reaches the handler as MessageSubscribeAnnounces and is accepted.
13 | func TestSubscribeAnnounces(t *testing.T) {
14 | 	t.Run("success", func(t *testing.T) {
15 | 		sConn, cConn, cancel := connect(t)
16 | 		defer cancel()
17 |
18 | 		handler := moqtransport.HandlerFunc(func(w moqtransport.ResponseWriter, m *moqtransport.Message) {
19 | 			assert.Equal(t, moqtransport.MessageSubscribeAnnounces, m.Method)
20 | 			assert.NotNil(t, w)
21 | 			assert.NoError(t, w.Accept())
22 | 		})
23 | 		_, ct, cancelSetup := setup(t, sConn, cConn, handler)
24 | 		defer cancelSetup()
25 |
26 | 		err := ct.SubscribeAnnouncements(context.Background(), []string{"test_prefix"})
27 | 		assert.NoError(t, err)
28 | 	})
29 | }
30 |
--------------------------------------------------------------------------------
/internal/wire/subscribe_announces_ok_message.go:
--------------------------------------------------------------------------------
1 | package wire
2 |
3 | import (
4 | 	"log/slog"
5 |
6 | 	"github.com/quic-go/quic-go/quicvarint"
7 | )
8 |
9 | // SubscribeAnnouncesOkMessage is the SUBSCRIBE_ANNOUNCES_OK control message
10 | // (its Type is messageTypeSubscribeNamespaceOk).
11 | //
12 | // TODO: Add tests
13 | type SubscribeAnnouncesOkMessage struct {
14 | 	RequestID uint64
15 | }
16 |
17 | // LogValue implements slog.LogValuer and includes the request ID like the
18 | // other request-ID-carrying messages of this package.
19 | func (m *SubscribeAnnouncesOkMessage) LogValue() slog.Value {
20 | 	return slog.GroupValue(
21 | 		slog.String("type", "subscribe_announces_ok"),
22 | 		slog.Uint64("request_id", m.RequestID),
23 | 	)
24 | }
25 |
26 | // Type returns messageTypeSubscribeNamespaceOk.
27 | func (m SubscribeAnnouncesOkMessage) Type() controlMessageType {
28 | 	return messageTypeSubscribeNamespaceOk
29 | }
30 |
31 | // Append appends the request ID as a varint to buf.
32 | func (m *SubscribeAnnouncesOkMessage) Append(buf []byte) []byte {
33 | 	return quicvarint.Append(buf, m.RequestID)
34 | }
35 |
36 | // parse decodes the request ID from data.
37 | func (m *SubscribeAnnouncesOkMessage) parse(_ Version, data []byte) (err error) {
38 | 	m.RequestID, _, err = quicvarint.Parse(data)
39 | 	return err
40 | }
41 |
--------------------------------------------------------------------------------
/internal/wire/requests_blocked_message.go:
--------------------------------------------------------------------------------
1 | package wire
2 |
3 | import (
4 | 	"log/slog"
5 |
6 | 	"github.com/quic-go/quic-go/quicvarint"
7 | )
8 |
9 | // RequestsBlockedMessage is the REQUESTS_BLOCKED control message carrying
10 | // a maximum request ID.
11 | type RequestsBlockedMessage struct {
12 | 	MaximumRequestID uint64
13 | }
14 |
15 | // LogValue implements slog.LogValuer.
16 | func (m *RequestsBlockedMessage) LogValue() slog.Value {
17 | 	return slog.GroupValue(
18 | 		slog.String("type", "requests_blocked"),
19 | 		slog.Uint64("max_request_id", m.MaximumRequestID),
20 | 	)
21 | }
22 |
23 | // Type returns messageTypeRequestsBlocked.
24 | func (m RequestsBlockedMessage) Type() controlMessageType {
25 | 	return messageTypeRequestsBlocked
26 | }
27 |
28 | // Append appends the maximum request ID as a varint to buf.
29 | func (m *RequestsBlockedMessage) Append(buf []byte) []byte {
30 | 	return quicvarint.Append(buf, m.MaximumRequestID)
31 | }
32 |
33 | // parse decodes the maximum request ID from data.
34 | func (m *RequestsBlockedMessage) parse(_ Version, data []byte) (err error) {
35 | 	m.MaximumRequestID, _, err = quicvarint.Parse(data)
36 | 	return err
37 | }
38 |
--------------------------------------------------------------------------------
/quicmoq/send_stream.go:
--------------------------------------------------------------------------------
1 | package quicmoq
2 |
3 | import (
4 | 	"github.com/mengelbart/moqtransport"
5 | 	"github.com/quic-go/quic-go"
6 | )
7 |
8 | var _ moqtransport.SendStream = (*SendStream)(nil)
9 |
10 | // SendStream adapts a *quic.SendStream to moqtransport.SendStream.
11 | type SendStream struct {
12 | 	stream *quic.SendStream
13 | }
14 |
15 | // Write implements moqtransport.SendStream.
16 | func (s *SendStream) Write(p []byte) (n int, err error) {
17 | 	return s.stream.Write(p)
18 | }
19 |
20 | // Reset implements moqtransport.SendStream by cancelling writes with code.
21 | func (s *SendStream) Reset(code uint32) {
22 | 	s.stream.CancelWrite(quic.StreamErrorCode(code))
23 | }
24 |
25 | // Close implements moqtransport.SendStream.
26 | func (s *SendStream) Close() error {
27 | 	return s.stream.Close()
28 | }
29 |
30 | // StreamID implements moqtransport.SendStream.
31 | func (s *SendStream) StreamID() uint64 {
32 | 	return uint64(s.stream.StreamID())
33 | }
34 |
--------------------------------------------------------------------------------
/internal/wire/unsubscribe_announces_message.go:
--------------------------------------------------------------------------------
1 | package wire
2 |
3 | import (
4 | 	"log/slog"
5 | )
6 |
7 | // UnsubscribeAnnouncesMessage is the UNSUBSCRIBE_ANNOUNCES control message
8 | // (its Type is messageTypeUnsubscribeNamespace).
9 | //
10 | // TODO: Add tests
11 | type UnsubscribeAnnouncesMessage struct {
12 | 	TrackNamespacePrefix Tuple
13 | }
14 |
15 | // LogValue implements slog.LogValuer.
16 | func (m *UnsubscribeAnnouncesMessage) LogValue() slog.Value {
17 | 	return slog.GroupValue(
18 | 		slog.String("type", "unsubscribe_announces"),
19 | 		slog.Any("track_namespace_prefix", m.TrackNamespacePrefix),
20 | 	)
21 | }
22 |
23 | // Type returns messageTypeUnsubscribeNamespace.
24 | func (m UnsubscribeAnnouncesMessage) Type() controlMessageType {
25 | 	return messageTypeUnsubscribeNamespace
26 | }
27 |
28 | // Append appends the encoded namespace prefix to buf.
29 | func (m *UnsubscribeAnnouncesMessage) Append(buf []byte) []byte {
30 | 	return m.TrackNamespacePrefix.append(buf)
31 | }
32 |
33 | // parse decodes the namespace prefix from data.
34 | func (m *UnsubscribeAnnouncesMessage) parse(_ Version, data []byte) (err error) {
35 | 	m.TrackNamespacePrefix, _, err = parseTuple(data)
36 | 	return err
37 | }
38 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/mengelbart/moqtransport
2 |
3 | go 1.23.6
4 |
5 | require (
6 | 	github.com/mengelbart/qlog v0.1.0
7 | 	github.com/quic-go/quic-go v0.53.0
8 | 	github.com/quic-go/webtransport-go v0.9.0
9 | 	github.com/stretchr/testify v1.10.0
10 | 	go.uber.org/goleak v1.3.0
11 | 	go.uber.org/mock v0.5.0
12 | 	golang.org/x/sync v0.12.0
13 | )
14 |
15 | require (
16 | 	github.com/davecgh/go-spew v1.1.1 // indirect
17 | 	github.com/kr/text v0.2.0 // indirect
18 | 	github.com/pmezard/go-difflib v1.0.0 // indirect
19 | 	github.com/quic-go/qpack v0.5.1 // indirect
20 | 	golang.org/x/crypto v0.36.0 // indirect
21 | 	golang.org/x/mod v0.22.0 // indirect
22 | 	golang.org/x/net v0.38.0 // indirect
23 | 	golang.org/x/sys v0.31.0 // indirect
24 | 	golang.org/x/text v0.23.0 // indirect
25 | 	golang.org/x/tools v0.29.0 // indirect
26 | 	gopkg.in/yaml.v3 v3.0.1 // indirect
27 | )
28 |
--------------------------------------------------------------------------------
/webtransportmoq/send_stream.go:
--------------------------------------------------------------------------------
1 | package webtransportmoq
2 |
3 | import (
4 | 	"github.com/mengelbart/moqtransport"
5 | 	"github.com/quic-go/webtransport-go"
6 | )
7 |
8 | var _ moqtransport.SendStream = (*SendStream)(nil)
9 |
10 | // SendStream adapts a *webtransport.SendStream to moqtransport.SendStream.
11 | type SendStream struct {
12 | 	stream *webtransport.SendStream
13 | }
14 |
15 | // Write implements moqtransport.SendStream.
16 | func (s *SendStream) Write(p []byte) (n int, err error) {
17 | 	return s.stream.Write(p)
18 | }
19 |
20 | // Reset implements moqtransport.SendStream by cancelling writes with code.
21 | func (s *SendStream) Reset(code uint32) {
22 | 	s.stream.CancelWrite(webtransport.StreamErrorCode(code))
23 | }
24 |
25 | // Close implements moqtransport.SendStream.
26 | func (s *SendStream) Close() error {
27 | 	return s.stream.Close()
28 | }
29 |
30 | // StreamID implements moqtransport.SendStream.
31 | func (s *SendStream) StreamID() uint64 {
32 | 	return uint64(s.stream.StreamID())
33 | }
34 |
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | on:
2 |   push:
3 |     branches:
4 |       - master
5 |       - main
6 |   pull_request:
7 |
8 | jobs:
9 |   golangci:
10 |     runs-on: ubuntu-latest
11 |     name: lint
12 |     steps:
13 |       - uses: actions/checkout@v4
14 |       - uses: actions/setup-go@v5
15 |         with:
16 |           go-version: 'stable'
17 |           cache: false
18 |       - name: golangci-lint
19 |         uses: golangci/golangci-lint-action@v6
20 |         with:
21 |           version: v1.64
22 |   test:
23 |     strategy:
24 |       fail-fast: false
25 |       matrix:
26 |         go: [ '1.23.x', '1.24.x' ]
27 |     runs-on: ubuntu-latest
28 |     name: Unit tests (${{ matrix.go }})
29 |     steps:
30 |       - uses: actions/checkout@v4
31 |       - uses: actions/setup-go@v5
32 |         with:
33 |           go-version: ${{ matrix.go }}
34 |       - name: Build
35 |         run: go build -v ./...
36 |       - name: Test
37 |         run: go test -v -race ./...
38 |
--------------------------------------------------------------------------------
/logging.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | import (
4 | 	"log/slog"
5 | 	"math"
6 | 	"os"
7 | 	"strings"
8 | )
9 |
10 | // logEnv is the environment variable read by readLoggingEnv.
11 | const logEnv = "MOQ_LOG_LEVEL"
12 |
13 | const (
14 | 	// LogLevelNone is above every predefined slog level, so while it is
15 | 	// the active level defaultLogger discards records at those levels. It
16 | 	// is used when MOQ_LOG_LEVEL is unset or unrecognized.
17 | 	LogLevelNone = math.MaxInt
18 | )
19 |
20 | // moqtransportLogLevel is the dynamic level of defaultLogger.
21 | var moqtransportLogLevel = new(slog.LevelVar)
22 |
23 | var defaultLogger *slog.Logger
24 |
25 | // init creates defaultLogger: a text handler writing to stderr, filtered by
26 | // moqtransportLogLevel as read from MOQ_LOG_LEVEL.
27 | func init() {
28 | 	h := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
29 | 		AddSource: false,
30 | 		Level:     moqtransportLogLevel,
31 | 	})
32 | 	moqtransportLogLevel.Set(readLoggingEnv())
33 | 	defaultLogger = slog.New(h)
34 | }
35 |
36 | // readLoggingEnv maps MOQ_LOG_LEVEL (case-insensitively) to a slog level:
37 | // "debug", "info", "warn" or "error"; anything else, including an unset or
38 | // empty variable, yields LogLevelNone.
39 | func readLoggingEnv() slog.Level {
40 | 	switch strings.ToLower(os.Getenv(logEnv)) {
41 | 	case "debug":
42 | 		return slog.LevelDebug
43 | 	case "info":
44 | 		return slog.LevelInfo
45 | 	case "warn":
46 | 		return slog.LevelWarn
47 | 	case "error":
48 | 		return slog.LevelError
49 | 	default:
50 | 		return LogLevelNone
51 | 	}
52 | }
53 |
--------------------------------------------------------------------------------
/track_status_response_writer.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | // trackStatusResponseWriter collects the status for a track status request
4 | // and sends it through session.
5 | type trackStatusResponseWriter struct {
6 | 	session *Session
7 | 	handled bool
8 | 	status  TrackStatus
9 | }
10 |
11 | // Accept commits the status and sends a response to the peer.
12 | func (w *trackStatusResponseWriter) Accept() error {
13 | 	w.handled = true
14 | 	return w.session.sendTrackStatus(w.status)
15 | }
16 |
17 | // Reject sends a track does not exist status. The code and reason
18 | // arguments are ignored.
19 | func (w *trackStatusResponseWriter) Reject(uint64, string) error {
20 | 	w.handled = true
21 | 	w.status.StatusCode = TrackStatusDoesNotExist
22 | 	w.status.LastGroupID = 0
23 | 	w.status.LastObjectID = 0
24 | 	return w.Accept()
25 | }
26 |
27 | // SetStatus implements StatusRequestHandler. It only records the values;
28 | // nothing is sent until Accept is called.
29 | func (w *trackStatusResponseWriter) SetStatus(statusCode uint64, lastGroupID uint64, lastObjectID uint64) {
30 | 	w.status.StatusCode = statusCode
31 | 	w.status.LastGroupID = lastGroupID
32 | 	w.status.LastObjectID = lastObjectID
33 | }
34 |
--------------------------------------------------------------------------------
/internal/wire/go_away_message.go:
--------------------------------------------------------------------------------
1 | package wire
2 |
3 | import (
4 | 	"log/slog"
5 |
6 | 	"github.com/mengelbart/qlog"
7 | )
8 |
9 | // GoAwayMessage is the GOAWAY control message.
10 | type GoAwayMessage struct {
11 | 	NewSessionURI string
12 | }
13 |
14 | // LogValue implements slog.LogValuer; the URI is logged as qlog raw info.
15 | func (m *GoAwayMessage) LogValue() slog.Value {
16 | 	return slog.GroupValue(
17 | 		slog.String("type", "goaway"),
18 | 		slog.Any("new_session_uri", qlog.RawInfo{
19 | 			Length:        uint64(len(m.NewSessionURI)),
20 | 			PayloadLength: uint64(len(m.NewSessionURI)),
21 | 			Data:          []byte(m.NewSessionURI),
22 | 		}),
23 | 	)
24 | }
25 |
26 | // Type returns messageTypeGoAway.
27 | func (m GoAwayMessage) Type() controlMessageType {
28 | 	return messageTypeGoAway
29 | }
30 |
31 | // Append appends the new session URI, length-prefixed, to buf.
32 | func (m *GoAwayMessage) Append(buf []byte) []byte {
33 | 	return appendVarIntBytes(buf, []byte(m.NewSessionURI))
34 | }
35 |
36 | // parse decodes the length-prefixed new session URI from data. The bytes
37 | // are copied into a string, so the result does not alias data.
38 | func (m *GoAwayMessage) parse(_ Version, data []byte) (err error) {
39 | 	newSessionURI, _, err := parseVarIntBytes(data)
40 | 	m.NewSessionURI = string(newSessionURI)
41 | 	return err
42 | }
43 |
--------------------------------------------------------------------------------
/request_id.go:
--------------------------------------------------------------------------------
1 | package moqtransport
2 |
3 | import (
4 | 	"errors"
5 | 	"sync"
6 | )
7 |
8 | var errRequestIDblocked = errors.New("request IDs blocked")
9 |
10 | type requestIDGenerator struct {
11 | 	lock     sync.Mutex
12 | 	id       uint64
13 | 	max      uint64
14 | 	interval uint64
15 | }
16 |
17 | func newRequestIDGenerator(initialID, maxID, interval uint64) *requestIDGenerator {
18 | 	return &requestIDGenerator{
19 | 		id:       initialID,
20 | 		max:      maxID,
21 | 		interval: interval,
22 | } 23 | } 24 | 25 | func (g *requestIDGenerator) next() (uint64, error) { 26 | g.lock.Lock() 27 | defer g.lock.Unlock() 28 | if g.id >= g.max { 29 | return g.max, errRequestIDblocked 30 | } 31 | next := g.id 32 | g.id += g.interval 33 | return next, nil 34 | } 35 | 36 | func (g *requestIDGenerator) setMax(v uint64) error { 37 | g.lock.Lock() 38 | defer g.lock.Unlock() 39 | if v < g.max { 40 | return errMaxRequestIDDecreased 41 | } 42 | g.max = v 43 | return nil 44 | } 45 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | ## Pull Requests 4 | 5 | Contributions are welcome. Please consider opening an issue to discuss the 6 | proposed change before putting in a lot of work. 7 | 8 | ## Commit Messages 9 | 10 | Commit messages should follow [these guidelines](https://git-scm.com/book/en/v2/Distributed-Git-Contributing-to-a-Project#_commit_guidelines): 11 | 12 | * The subject line must be capitalized 13 | * The subject line must be written in the imperative 14 | * The subject line must not end in a period 15 | * The subject line should not exceed about 50 characters 16 | * A body can optionally be added to explain the change in more detail 17 | * The body must be separated from the subject line with a blank line 18 | * The body should be wrapped at 72 characters 19 | 20 | Example: 21 | 22 | ``` 23 | Add guidelines for contributing 24 | 25 | Add a contributing.md file to document guidelines for contributing to 26 | moqtransport. The documentation contains guidelines for pull requests 27 | and commit messages and should be extended in the future.
28 | ``` 29 | 30 | -------------------------------------------------------------------------------- /internal/wire/stream_header_subgroup_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "github.com/quic-go/quic-go/quicvarint" 5 | ) 6 | 7 | type SubgroupHeaderMessage struct { 8 | TrackAlias uint64 9 | GroupID uint64 10 | SubgroupID uint64 11 | PublisherPriority uint8 12 | } 13 | 14 | func (m *SubgroupHeaderMessage) Append(buf []byte) []byte { 15 | buf = quicvarint.Append(buf, uint64(StreamTypeSubgroupSIDExt)) 16 | buf = quicvarint.Append(buf, m.TrackAlias) 17 | buf = quicvarint.Append(buf, m.GroupID) 18 | buf = quicvarint.Append(buf, m.SubgroupID) 19 | return append(buf, m.PublisherPriority) 20 | } 21 | 22 | func (m *SubgroupHeaderMessage) parse(reader messageReader, sid bool) (err error) { 23 | m.TrackAlias, err = quicvarint.Read(reader) 24 | if err != nil { 25 | return 26 | } 27 | m.GroupID, err = quicvarint.Read(reader) 28 | if err != nil { 29 | return 30 | } 31 | if sid { 32 | m.SubgroupID, err = quicvarint.Read(reader) 33 | if err != nil { 34 | return 35 | } 36 | } 37 | m.PublisherPriority, err = reader.ReadByte() 38 | return 39 | } 40 | -------------------------------------------------------------------------------- /quicmoq/stream.go: -------------------------------------------------------------------------------- 1 | package quicmoq 2 | 3 | import ( 4 | "github.com/mengelbart/moqtransport" 5 | "github.com/quic-go/quic-go" 6 | ) 7 | 8 | var _ moqtransport.Stream = (*Stream)(nil) 9 | 10 | type Stream struct { 11 | stream *quic.Stream 12 | } 13 | 14 | // Read implements moqtransport.Stream. 15 | func (s *Stream) Read(p []byte) (n int, err error) { 16 | return s.stream.Read(p) 17 | } 18 | 19 | // Write implements moqtransport.Stream. 20 | func (s *Stream) Write(p []byte) (n int, err error) { 21 | return s.stream.Write(p) 22 | } 23 | 24 | // Close implements moqtransport.Stream. 
25 | func (s *Stream) Close() error { 26 | return s.stream.Close() 27 | } 28 | 29 | // Reset implements moqtransport.Stream. 30 | func (s *Stream) Reset(code uint32) { 31 | s.stream.CancelWrite(quic.StreamErrorCode(code)) 32 | } 33 | 34 | // Stop implements moqtransport.Stream. 35 | func (s *Stream) Stop(code uint32) { 36 | s.stream.CancelRead(quic.StreamErrorCode(code)) 37 | } 38 | 39 | // StreamID implements moqtransport.Stream. 40 | func (s *Stream) StreamID() uint64 { 41 | return uint64(s.stream.StreamID()) 42 | } 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mathis Engelbart 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /webtransportmoq/stream.go: -------------------------------------------------------------------------------- 1 | package webtransportmoq 2 | 3 | import ( 4 | "github.com/mengelbart/moqtransport" 5 | "github.com/quic-go/webtransport-go" 6 | ) 7 | 8 | var _ moqtransport.Stream = (*Stream)(nil) 9 | 10 | type Stream struct { 11 | stream *webtransport.Stream 12 | } 13 | 14 | // Read implements moqtransport.Stream. 15 | func (s *Stream) Read(p []byte) (n int, err error) { 16 | return s.stream.Read(p) 17 | } 18 | 19 | // Write implements moqtransport.Stream. 20 | func (s *Stream) Write(p []byte) (n int, err error) { 21 | return s.stream.Write(p) 22 | } 23 | 24 | // Close implements moqtransport.Stream. 25 | func (s *Stream) Close() error { 26 | return s.stream.Close() 27 | } 28 | 29 | // Reset implements moqtransport.Stream. 30 | func (s *Stream) Reset(code uint32) { 31 | s.stream.CancelWrite(webtransport.StreamErrorCode(code)) 32 | } 33 | 34 | // Stop implements moqtransport.Stream. 35 | func (s *Stream) Stop(code uint32) { 36 | s.stream.CancelRead(webtransport.StreamErrorCode(code)) 37 | } 38 | 39 | // StreamID implements moqtransport.Stream. 
40 | func (s *Stream) StreamID() uint64 { 41 | return uint64(s.stream.StreamID()) 42 | } 43 | -------------------------------------------------------------------------------- /control_stream.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import ( 4 | "iter" 5 | "log/slog" 6 | 7 | "github.com/mengelbart/moqtransport/internal/wire" 8 | "github.com/mengelbart/qlog" 9 | "github.com/mengelbart/qlog/moqt" 10 | ) 11 | 12 | type controlStream struct { 13 | stream Stream 14 | logger *slog.Logger 15 | qlogger *qlog.Logger 16 | } 17 | 18 | func (s *controlStream) read() iter.Seq2[wire.ControlMessage, error] { 19 | parser := wire.NewControlMessageParser(s.stream) 20 | return func(yield func(wire.ControlMessage, error) bool) { 21 | for { 22 | msg, err := parser.Parse() 23 | if !yield(msg, err) { 24 | return 25 | } 26 | } 27 | } 28 | } 29 | 30 | func (s *controlStream) write(msg wire.ControlMessage) error { 31 | buf, err := compileMessage(msg) 32 | if err != nil { 33 | return err 34 | } 35 | if s.qlogger != nil { 36 | s.qlogger.Log(moqt.ControlMessageEvent{ 37 | EventName: moqt.ControlMessageEventCreated, 38 | StreamID: s.stream.StreamID(), 39 | Length: uint64(len(buf)), 40 | Message: msg, 41 | }) 42 | } 43 | s.logger.Info("sending message", "type", msg.Type().String(), "msg", msg) 44 | _, err = s.stream.Write(buf) 45 | if err != nil { 46 | return err 47 | } 48 | return nil 49 | } 50 | -------------------------------------------------------------------------------- /internal/wire/announce_error_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type AnnounceErrorMessage struct { 10 | RequestID uint64 11 | ErrorCode uint64 12 | ReasonPhrase string 13 | } 14 | 15 | func (m *AnnounceErrorMessage) LogValue() slog.Value { 16 | return slog.GroupValue( 17 | 
slog.String("type", "announce_error"), 18 | slog.Uint64("error_code", m.ErrorCode), 19 | slog.String("reason", m.ReasonPhrase), 20 | ) 21 | } 22 | 23 | func (m AnnounceErrorMessage) Type() controlMessageType { 24 | return messageTypeAnnounceError 25 | } 26 | 27 | func (m *AnnounceErrorMessage) Append(buf []byte) []byte { 28 | buf = quicvarint.Append(buf, m.RequestID) 29 | buf = quicvarint.Append(buf, m.ErrorCode) 30 | buf = appendVarIntBytes(buf, []byte(m.ReasonPhrase)) 31 | return buf 32 | } 33 | 34 | func (m *AnnounceErrorMessage) parse(_ Version, data []byte) (err error) { 35 | var n int 36 | m.RequestID, n, err = quicvarint.Parse(data) 37 | if err != nil { 38 | return err 39 | } 40 | data = data[n:] 41 | 42 | m.ErrorCode, n, err = quicvarint.Parse(data) 43 | if err != nil { 44 | return err 45 | } 46 | data = data[n:] 47 | 48 | reasonPhrase, _, err := parseVarIntBytes(data) 49 | m.ReasonPhrase = string(reasonPhrase) 50 | return err 51 | } 52 | -------------------------------------------------------------------------------- /integrationtests/announce_test.go: -------------------------------------------------------------------------------- 1 | package integrationtests 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/mengelbart/moqtransport" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestAnnounce(t *testing.T) { 12 | t.Run("success", func(t *testing.T) { 13 | sConn, cConn, cancel := connect(t) 14 | defer cancel() 15 | 16 | handler := moqtransport.HandlerFunc(func(w moqtransport.ResponseWriter, m *moqtransport.Message) { 17 | assert.Equal(t, moqtransport.MessageAnnounce, m.Method) 18 | assert.NotNil(t, w) 19 | assert.NoError(t, w.Accept()) 20 | }) 21 | _, ct, cancel := setup(t, sConn, cConn, handler) 22 | defer cancel() 23 | 24 | err := ct.Announce(context.Background(), []string{"namespace"}) 25 | assert.NoError(t, err) 26 | }) 27 | t.Run("error", func(t *testing.T) { 28 | sConn, cConn, cancel := connect(t) 29 | defer cancel() 30 | 
31 | handler := moqtransport.HandlerFunc(func(w moqtransport.ResponseWriter, m *moqtransport.Message) { 32 | assert.Equal(t, moqtransport.MessageAnnounce, m.Method) 33 | assert.NotNil(t, w) 34 | assert.NoError(t, w.Reject(uint64(moqtransport.ErrorCodeAnnounceInternal), "expected error")) 35 | }) 36 | _, ct, cancel := setup(t, sConn, cConn, handler) 37 | defer cancel() 38 | 39 | err := ct.Announce(context.Background(), []string{"namespace"}) 40 | assert.Error(t, err) 41 | }) 42 | } 43 | -------------------------------------------------------------------------------- /internal/wire/fetch_error_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | // TODO: Add tests 10 | type FetchErrorMessage struct { 11 | RequestID uint64 12 | ErrorCode uint64 13 | ReasonPhrase string 14 | } 15 | 16 | func (m *FetchErrorMessage) LogValue() slog.Value { 17 | return slog.GroupValue( 18 | slog.String("type", "fetch_error"), 19 | slog.Uint64("request_id", m.RequestID), 20 | slog.Uint64("error_code", m.ErrorCode), 21 | slog.String("reason", m.ReasonPhrase), 22 | ) 23 | } 24 | 25 | func (m FetchErrorMessage) Type() controlMessageType { 26 | return messageTypeFetchError 27 | } 28 | 29 | func (m *FetchErrorMessage) Append(buf []byte) []byte { 30 | buf = quicvarint.Append(buf, m.RequestID) 31 | buf = quicvarint.Append(buf, m.ErrorCode) 32 | return appendVarIntBytes(buf, []byte(m.ReasonPhrase)) 33 | } 34 | 35 | func (m *FetchErrorMessage) parse(_ Version, data []byte) (err error) { 36 | var n int 37 | m.RequestID, n, err = quicvarint.Parse(data) 38 | if err != nil { 39 | return err 40 | } 41 | data = data[n:] 42 | 43 | m.ErrorCode, n, err = quicvarint.Parse(data) 44 | if err != nil { 45 | return err 46 | } 47 | data = data[n:] 48 | 49 | reasonPhrase, _, err := parseVarIntBytes(data) 50 | if err != nil { 51 | return err 52 | } 53 | m.ReasonPhrase 
= string(reasonPhrase) 54 | return nil 55 | } 56 | -------------------------------------------------------------------------------- /internal/wire/server_setup_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type ServerSetupMessage struct { 10 | SelectedVersion Version 11 | SetupParameters KVPList 12 | } 13 | 14 | func (m *ServerSetupMessage) LogValue() slog.Value { 15 | attrs := []slog.Attr{ 16 | slog.String("type", "server_setup"), 17 | slog.Uint64("selected_version", uint64(m.SelectedVersion)), 18 | slog.Uint64("number_of_parameters", uint64(len(m.SetupParameters))), 19 | } 20 | if len(m.SetupParameters) > 0 { 21 | attrs = append(attrs, 22 | slog.Any("setup_parameters", m.SetupParameters), 23 | ) 24 | } 25 | return slog.GroupValue(attrs...) 26 | } 27 | 28 | func (m ServerSetupMessage) Type() controlMessageType { 29 | return messageTypeServerSetup 30 | } 31 | 32 | func (m *ServerSetupMessage) Append(buf []byte) []byte { 33 | buf = quicvarint.Append(buf, uint64(m.SelectedVersion)) 34 | buf = quicvarint.Append(buf, uint64(len(m.SetupParameters))) 35 | for _, p := range m.SetupParameters { 36 | buf = p.append(buf) 37 | } 38 | return buf 39 | } 40 | 41 | func (m *ServerSetupMessage) parse(_ Version, data []byte) error { 42 | sv, n, err := quicvarint.Parse(data) 43 | if err != nil { 44 | return err 45 | } 46 | data = data[n:] 47 | 48 | m.SelectedVersion = Version(sv) 49 | m.SetupParameters = KVPList{} 50 | return m.SetupParameters.parseNum(data) 51 | } 52 | -------------------------------------------------------------------------------- /internal/wire/subscribe_announces_error_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | // TODO: Add tests 10 | type 
SubscribeAnnouncesErrorMessage struct { 11 | RequestID uint64 12 | ErrorCode uint64 13 | ReasonPhrase string 14 | } 15 | 16 | func (m *SubscribeAnnouncesErrorMessage) LogValue() slog.Value { 17 | return slog.GroupValue( 18 | slog.String("type", "subscribe_announces_error"), 19 | slog.Uint64("error_code", m.ErrorCode), 20 | slog.String("reason", m.ReasonPhrase), 21 | ) 22 | } 23 | 24 | func (m SubscribeAnnouncesErrorMessage) Type() controlMessageType { 25 | return messageTypeSubscribeNamespaceError 26 | } 27 | 28 | func (m *SubscribeAnnouncesErrorMessage) Append(buf []byte) []byte { 29 | buf = quicvarint.Append(buf, m.RequestID) 30 | buf = quicvarint.Append(buf, m.ErrorCode) 31 | return appendVarIntBytes(buf, []byte(m.ReasonPhrase)) 32 | } 33 | 34 | func (m *SubscribeAnnouncesErrorMessage) parse(_ Version, data []byte) (err error) { 35 | var n int 36 | m.RequestID, n, err = quicvarint.Parse(data) 37 | if err != nil { 38 | return err 39 | } 40 | data = data[n:] 41 | 42 | m.ErrorCode, n, err = quicvarint.Parse(data) 43 | if err != nil { 44 | return err 45 | } 46 | data = data[n:] 47 | 48 | reasonPhrase, _, err := parseVarIntBytes(data) 49 | if err != nil { 50 | return err 51 | } 52 | m.ReasonPhrase = string(reasonPhrase) 53 | return nil 54 | } 55 | -------------------------------------------------------------------------------- /internal/wire/publish_error_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type PublishErrorMessage struct { 10 | RequestID uint64 11 | ErrorCode uint64 12 | ReasonPhrase string 13 | } 14 | 15 | func (m *PublishErrorMessage) LogValue() slog.Value { 16 | return slog.GroupValue( 17 | slog.String("type", "publish_error"), 18 | slog.Uint64("request_id", m.RequestID), 19 | slog.Uint64("error_code", m.ErrorCode), 20 | slog.String("reason", m.ReasonPhrase), 21 | slog.Any("reason_bytes",
[]byte(m.ReasonPhrase)), 22 | ) 23 | } 24 | 25 | func (m PublishErrorMessage) Type() controlMessageType { 26 | return messageTypePublishError 27 | } 28 | 29 | func (m *PublishErrorMessage) Append(buf []byte) []byte { 30 | buf = quicvarint.Append(buf, m.RequestID) 31 | buf = quicvarint.Append(buf, uint64(m.ErrorCode)) 32 | buf = appendVarIntBytes(buf, []byte(m.ReasonPhrase)) 33 | return buf 34 | } 35 | 36 | func (m *PublishErrorMessage) parse(_ Version, data []byte) (err error) { 37 | var n int 38 | m.RequestID, n, err = quicvarint.Parse(data) 39 | if err != nil { 40 | return err 41 | } 42 | data = data[n:] 43 | 44 | m.ErrorCode, n, err = quicvarint.Parse(data) 45 | if err != nil { 46 | return err 47 | } 48 | data = data[n:] 49 | 50 | reasonPhrase, _, err := parseVarIntBytes(data) 51 | if err != nil { 52 | return err 53 | } 54 | m.ReasonPhrase = string(reasonPhrase) 55 | return nil 56 | } 57 | -------------------------------------------------------------------------------- /internal/wire/subscribe_error_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type SubscribeErrorMessage struct { 10 | RequestID uint64 11 | ErrorCode uint64 12 | ReasonPhrase string 13 | } 14 | 15 | func (m *SubscribeErrorMessage) LogValue() slog.Value { 16 | return slog.GroupValue( 17 | slog.String("type", "subscribe_error"), 18 | slog.Uint64("request_id", m.RequestID), 19 | slog.Uint64("error_code", m.ErrorCode), 20 | slog.String("reason", m.ReasonPhrase), 21 | slog.Any("reason_bytes", []byte(m.ReasonPhrase)), 22 | ) 23 | } 24 | 25 | func (m SubscribeErrorMessage) Type() controlMessageType { 26 | return messageTypeSubscribeError 27 | } 28 | 29 | func (m *SubscribeErrorMessage) Append(buf []byte) []byte { 30 | buf = quicvarint.Append(buf, m.RequestID) 31 | buf = quicvarint.Append(buf, uint64(m.ErrorCode)) 32 | buf = appendVarIntBytes(buf,
[]byte(m.ReasonPhrase)) 33 | return buf 34 | } 35 | 36 | func (m *SubscribeErrorMessage) parse(_ Version, data []byte) (err error) { 37 | var n int 38 | m.RequestID, n, err = quicvarint.Parse(data) 39 | if err != nil { 40 | return err 41 | } 42 | data = data[n:] 43 | 44 | m.ErrorCode, n, err = quicvarint.Parse(data) 45 | if err != nil { 46 | return err 47 | } 48 | data = data[n:] 49 | 50 | reasonPhrase, _, err := parseVarIntBytes(data) 51 | if err != nil { 52 | return err 53 | } 54 | m.ReasonPhrase = string(reasonPhrase) 55 | return nil 56 | } 57 | -------------------------------------------------------------------------------- /internal/wire/announce_cancel_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type AnnounceCancelMessage struct { 10 | TrackNamespace Tuple 11 | ErrorCode uint64 12 | ReasonPhrase string 13 | } 14 | 15 | func (m *AnnounceCancelMessage) LogValue() slog.Value { 16 | return slog.GroupValue( 17 | slog.String("type", "announce_cancel"), 18 | slog.Any("track_namespace", m.TrackNamespace), 19 | slog.Uint64("error_code", m.ErrorCode), 20 | slog.String("reason", m.ReasonPhrase), 21 | ) 22 | } 23 | 24 | func (m AnnounceCancelMessage) GetTrackNamespace() string { 25 | return m.TrackNamespace.String() 26 | } 27 | 28 | func (m AnnounceCancelMessage) Type() controlMessageType { 29 | return messageTypeAnnounceCancel 30 | } 31 | 32 | func (m *AnnounceCancelMessage) Append(buf []byte) []byte { 33 | buf = m.TrackNamespace.append(buf) 34 | buf = quicvarint.Append(buf, m.ErrorCode) 35 | buf = appendVarIntBytes(buf, []byte(m.ReasonPhrase)) 36 | return buf 37 | } 38 | 39 | func (m *AnnounceCancelMessage) parse(_ Version, data []byte) (err error) { 40 | var n int 41 | m.TrackNamespace, n, err = parseTuple(data) 42 | if err != nil { 43 | return err 44 | } 45 | data = data[n:] 46 | 47 | m.ErrorCode, n, err =
quicvarint.Parse(data) 48 | if err != nil { 49 | return err 50 | } 51 | data = data[n:] 52 | 53 | reasonPhrase, _, err := parseVarIntBytes(data) 54 | m.ReasonPhrase = string(reasonPhrase) 55 | return err 56 | } 57 | -------------------------------------------------------------------------------- /internal/wire/announce_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type AnnounceMessage struct { 10 | RequestID uint64 11 | TrackNamespace Tuple 12 | Parameters KVPList 13 | } 14 | 15 | func (m *AnnounceMessage) LogValue() slog.Value { 16 | attrs := []slog.Attr{ 17 | slog.String("type", "announce"), 18 | slog.Any("track_namespace", m.TrackNamespace), 19 | slog.Uint64("number_of_parameters", uint64(len(m.Parameters))), 20 | } 21 | if len(m.Parameters) > 0 { 22 | attrs = append(attrs, 23 | slog.Any("parameters", m.Parameters), 24 | ) 25 | } 26 | return slog.GroupValue(attrs...) 
27 | } 28 | 29 | func (m AnnounceMessage) GetTrackNamespace() string { 30 | return m.TrackNamespace.String() 31 | } 32 | 33 | func (m AnnounceMessage) Type() controlMessageType { 34 | return messageTypeAnnounce 35 | } 36 | 37 | func (m *AnnounceMessage) Append(buf []byte) []byte { 38 | buf = quicvarint.Append(buf, m.RequestID) 39 | buf = m.TrackNamespace.append(buf) 40 | return m.Parameters.appendNum(buf) 41 | } 42 | 43 | func (m *AnnounceMessage) parse(_ Version, data []byte) (err error) { 44 | var n int 45 | m.RequestID, n, err = quicvarint.Parse(data) 46 | if err != nil { 47 | return err 48 | } 49 | data = data[n:] 50 | 51 | m.TrackNamespace, n, err = parseTuple(data) 52 | if err != nil { 53 | return err 54 | } 55 | data = data[n:] 56 | 57 | m.Parameters = KVPList{} 58 | return m.Parameters.parseNum(data) 59 | } 60 | -------------------------------------------------------------------------------- /internal/wire/track_status_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type TrackStatusMessage struct { 10 | RequestID uint64 11 | StatusCode uint64 12 | LargestLocation Location 13 | Parameters KVPList 14 | } 15 | 16 | func (m *TrackStatusMessage) LogValue() slog.Value { 17 | return slog.GroupValue( 18 | slog.String("type", "track_status"), 19 | slog.Uint64("status_code", m.StatusCode), 20 | slog.Uint64("last_group_id", m.LargestLocation.Group), 21 | slog.Uint64("last_object_id", m.LargestLocation.Object), 22 | ) 23 | } 24 | 25 | func (m TrackStatusMessage) Type() controlMessageType { 26 | return messageTypeTrackStatusOk 27 | } 28 | 29 | func (m *TrackStatusMessage) Append(buf []byte) []byte { 30 | buf = quicvarint.Append(buf, m.RequestID) 31 | buf = quicvarint.Append(buf, m.StatusCode) 32 | buf = m.LargestLocation.append(buf) 33 | return m.Parameters.appendNum(buf) 34 | } 35 | 36 | func (m *TrackStatusMessage) 
parse(v Version, data []byte) (err error) { 37 | var n int 38 | m.RequestID, n, err = quicvarint.Parse(data) 39 | if err != nil { 40 | return 41 | } 42 | data = data[n:] 43 | 44 | m.StatusCode, n, err = quicvarint.Parse(data) 45 | if err != nil { 46 | return 47 | } 48 | data = data[n:] 49 | 50 | n, err = m.LargestLocation.parse(v, data) 51 | if err != nil { 52 | return 53 | } 54 | data = data[n:] 55 | 56 | m.Parameters = KVPList{} 57 | return m.Parameters.parseNum(data) 58 | } 59 | -------------------------------------------------------------------------------- /internal/wire/subscribe_announces_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | // TODO: Add tests 10 | type SubscribeAnnouncesMessage struct { 11 | RequestID uint64 12 | TrackNamespacePrefix Tuple 13 | Parameters KVPList 14 | } 15 | 16 | func (m *SubscribeAnnouncesMessage) LogValue() slog.Value { 17 | attrs := []slog.Attr{ 18 | slog.String("type", "subscribe_announces"), 19 | slog.Any("track_namespace_prefix", m.TrackNamespacePrefix), 20 | slog.Uint64("number_of_parameters", uint64(len(m.Parameters))), 21 | } 22 | if len(m.Parameters) > 0 { 23 | attrs = append(attrs, 24 | slog.Any("parameters", m.Parameters), 25 | ) 26 | } 27 | return slog.GroupValue(attrs...) 
28 | } 29 | 30 | func (m SubscribeAnnouncesMessage) Type() controlMessageType { 31 | return messageTypeSubscribeNamespace 32 | } 33 | 34 | func (m *SubscribeAnnouncesMessage) Append(buf []byte) []byte { 35 | buf = quicvarint.Append(buf, m.RequestID) 36 | buf = m.TrackNamespacePrefix.append(buf) 37 | return m.Parameters.appendNum(buf) 38 | } 39 | 40 | func (m *SubscribeAnnouncesMessage) parse(_ Version, data []byte) (err error) { 41 | var n int 42 | m.RequestID, n, err = quicvarint.Parse(data) 43 | if err != nil { 44 | return err 45 | } 46 | data = data[n:] 47 | 48 | m.TrackNamespacePrefix, n, err = parseTuple(data) 49 | if err != nil { 50 | return err 51 | } 52 | data = data[n:] 53 | 54 | m.Parameters = KVPList{} 55 | return m.Parameters.parseNum(data) 56 | } 57 | -------------------------------------------------------------------------------- /announcement_subscription_map.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import ( 4 | "slices" 5 | "sync" 6 | ) 7 | 8 | type announcementSubscriptionMap struct { 9 | lock sync.Mutex 10 | as map[uint64]*announcementSubscription 11 | } 12 | 13 | func newAnnouncementSubscriptionMap() *announcementSubscriptionMap { 14 | return &announcementSubscriptionMap{ 15 | lock: sync.Mutex{}, 16 | as: make(map[uint64]*announcementSubscription), 17 | } 18 | } 19 | 20 | func findAnnouncementSubscription(as map[uint64]*announcementSubscription, namespace []string) *announcementSubscription { 21 | for _, v := range as { 22 | if slices.Equal(namespace, v.namespace) { 23 | return v 24 | } 25 | } 26 | return nil 27 | } 28 | 29 | func (m *announcementSubscriptionMap) add(a *announcementSubscription) { 30 | m.lock.Lock() 31 | defer m.lock.Unlock() 32 | m.as[a.requestID] = a 33 | } 34 | 35 | // delete returns the deleted element (if present) and whether the entry was 36 | // present and removed. 
37 | func (m *announcementSubscriptionMap) delete(namespace []string) (*announcementSubscription, bool) { 38 | m.lock.Lock() 39 | defer m.lock.Unlock() 40 | as := findAnnouncementSubscription(m.as, namespace) 41 | if as != nil { 42 | delete(m.as, as.requestID) 43 | return as, true 44 | } 45 | return nil, false 46 | } 47 | 48 | func (m *announcementSubscriptionMap) deleteByID(requestID uint64) (*announcementSubscription, bool) { 49 | m.lock.Lock() 50 | defer m.lock.Unlock() 51 | as, ok := m.as[requestID] 52 | delete(m.as, requestID) 53 | return as, ok 54 | } 55 | -------------------------------------------------------------------------------- /internal/wire/tuple.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/mengelbart/moqtransport/internal/slices" 9 | "github.com/quic-go/quic-go/quicvarint" 10 | ) 11 | 12 | type Tuple []string 13 | 14 | func (t Tuple) append(buf []byte) []byte { 15 | buf = quicvarint.Append(buf, uint64(len(t))) 16 | for _, t := range t { 17 | buf = quicvarint.Append(buf, uint64(len(t))) 18 | buf = append(buf, t...) 
19 | } 20 | return buf 21 | } 22 | 23 | func (t Tuple) MarshalJSON() ([]byte, error) { 24 | elements := slices.Collect(slices.Map(t, func(s string) string { 25 | // json.Marshal escapes quotes, backslashes and control characters in s; it cannot fail for a string. 26 | value, _ := json.Marshal(s) 27 | return fmt.Sprintf(`{"value": %s}`, value) 28 | })) 29 | return []byte(json.RawMessage("[" + strings.Join(elements, ",") + "]")), nil 30 | } 31 | 32 | func (t Tuple) String() string { 33 | res := "" 34 | for _, t := range t { 35 | res += string(t) 36 | } 37 | return res 38 | } 39 | 40 | func parseTuple(data []byte) (Tuple, int, error) { 41 | length, parsed, err := quicvarint.Parse(data) 42 | if err != nil { 43 | return nil, parsed, err 44 | } 45 | data = data[parsed:] 46 | 47 | tuple := make([]string, 0, min(length, uint64(len(data)))) // length is peer-controlled; every element needs at least one byte 48 | for i := uint64(0); i < length; i++ { 49 | l, n, err := quicvarint.Parse(data) 50 | parsed += n 51 | if err != nil { 52 | return tuple, parsed, err 53 | } 54 | data = data[n:] 55 | 56 | if uint64(len(data)) < l { 57 | return tuple, parsed, errLengthMismatch 58 | } 59 | tuple = append(tuple, string(data[:l])) 60 | data = data[l:] 61 | parsed += int(l) 62 | } 63 | return tuple, parsed, nil 64 | } 65 | -------------------------------------------------------------------------------- /internal/wire/client_setup_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type ClientSetupMessage struct { 10 | SupportedVersions versions 11 | SetupParameters KVPList 12 | } 13 | 14 | func (m *ClientSetupMessage) LogValue() slog.Value { 15 | attrs := []slog.Attr{ 16 | slog.String("type", "client_setup"), 17 | slog.Uint64("number_of_supported_versions", uint64(len(m.SupportedVersions))), 18 | slog.Any("supported_versions", m.SupportedVersions), 19 | slog.Uint64("number_of_parameters", uint64(len(m.SetupParameters))), 20 | } 21 | if len(m.SetupParameters) > 0 { 22 | attrs = append(attrs, 23 | slog.Any("setup_parameters", m.SetupParameters), 24 | ) 25 | } 26 | return
slog.GroupValue(attrs...) 27 | } 28 | 29 | // Type reports the CLIENT_SETUP control message type. 30 | func (m ClientSetupMessage) Type() controlMessageType { 31 | return messageTypeClientSetup 32 | } 33 | 34 | // Append serializes the supported versions followed by the setup parameters 35 | // onto buf and returns the extended slice. 36 | func (m *ClientSetupMessage) Append(buf []byte) []byte { 37 | // versions.append writes the same count-prefixed varint list that 38 | // versions.parse reads back in parse below. 39 | buf = m.SupportedVersions.append(buf) 40 | buf = quicvarint.Append(buf, uint64(len(m.SetupParameters))) 41 | for _, p := range m.SetupParameters { 42 | buf = p.append(buf) 43 | } 44 | return buf 45 | } 46 | 47 | // parse decodes the supported versions and setup parameters from data. 48 | // The version argument is unused. 49 | func (m *ClientSetupMessage) parse(_ Version, data []byte) error { 50 | n, err := m.SupportedVersions.parse(data) 51 | if err != nil { 52 | return err 53 | } 54 | data = data[n:] 55 | m.SetupParameters = KVPList{} 56 | return m.SetupParameters.parseNum(data) 57 | } 58 | -------------------------------------------------------------------------------- /internal/wire/unsubscribe_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestUnsubscribeMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | usm UnsubscribeMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | usm: UnsubscribeMessage{ 19 | RequestID: 17, 20 | }, 21 | buf: []byte{}, 22 | expect: []byte{ 23 | 0x11, 24 | }, 25 | }, 26 | { 27 | usm: UnsubscribeMessage{ 28 | RequestID: 17, 29 | }, 30 | buf: []byte{0x0a, 0x0b}, 31 | expect: []byte{0x0a, 0x0b, 0x11}, 32 | }, 33 | } 34 | for i, tc := range cases { 35 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 36 | res := tc.usm.Append(tc.buf) 37 | assert.Equal(t, tc.expect, res) 38 | }) 39 | } 40 | } 41 | 42 | func 
TestParseUnsubscribeMessage(t *testing.T) { 43 | cases := []struct { 44 | data []byte 45 | expect *UnsubscribeMessage 46 | err error 47 | }{ 48 | { 49 | data: nil, 50 | expect: &UnsubscribeMessage{}, 51 | err: io.EOF, 52 | },
53 | { 54 | data: []byte{17}, 55 | expect: &UnsubscribeMessage{ 56 | RequestID: 17, 57 | }, 58 | err: nil, 59 | }, 60 | } 61 | for i, tc := range cases { 62 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 63 | res := &UnsubscribeMessage{} 64 | err := res.parse(CurrentVersion, tc.data) 65 | assert.Equal(t, tc.expect, res) 66 | if tc.err != nil { 67 | assert.Equal(t, tc.err, err) 68 | } else { 69 | assert.NoError(t, err) 70 | } 71 | }) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /internal/wire/go_away_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | // TestGoAwayMessageAppend checks that Append writes NewSessionURI as a length-prefixed string after any bytes already in buf. 11 | func TestGoAwayMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | gam GoAwayMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | gam: GoAwayMessage{ 19 | NewSessionURI: "", 20 | }, 21 | buf: []byte{}, 22 | expect: []byte{ 23 | 0x00, 24 | }, 25 | }, 26 | { 27 | gam: GoAwayMessage{ 28 | NewSessionURI: "uri", 29 | }, 30 | buf: []byte{0x0a, 0x0b}, 31 | expect: []byte{ 32 | 0x0a, 0x0b, 0x03, 'u', 'r', 'i', 33 | }, 34 | }, 35 | } 36 | for i, tc := range cases { 37 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 38 | res := tc.gam.Append(tc.buf) 39 | assert.Equal(t, tc.expect, res) 40 | }) 41 | } 42 | } 43 | // TestParseGoAwayMessage checks that empty input yields io.EOF and that a length-prefixed URI is decoded. 44 | func TestParseGoAwayMessage(t *testing.T) { 45 | cases := []struct { 46 | data []byte 47 | expect *GoAwayMessage 48 | err error 49 | }{ 50 | { 51 | data: nil, 52 | expect: &GoAwayMessage{}, 53 | err: io.EOF, 54 | }, 55 | { 56 | data: append([]byte{0x03}, "uri"...), 57 | expect: &GoAwayMessage{ 58 | NewSessionURI: "uri", 59 | }, 60 | err: nil, 61 | }, 62 | } 63 | for i, tc := range cases { 
64 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 65 | res := &GoAwayMessage{} 66 | err := res.parse(CurrentVersion, tc.data) 67 | assert.Equal(t, tc.expect, res) 68 | if tc.err
!= nil { 69 | assert.Equal(t, tc.err, err) 70 | } else { 71 | assert.NoError(t, err) 72 | } 73 | }) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /internal/wire/control_message_parser_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | // mockReader is an io.Reader that serves one queued chunk per Read call and returns io.EOF once every chunk has been served. 11 | type mockReader struct { 12 | reads [][]byte 13 | index int 14 | } 15 | // Read copies the next queued chunk into p. NOTE(review): if p is shorter than the chunk, the rest of that chunk is dropped, so these tests assume the parser reads with a buffer at least as large as each chunk. 16 | func (r *mockReader) Read(p []byte) (int, error) { 17 | if r.index == len(r.reads) { 18 | return 0, io.EOF 19 | } 20 | n := copy(p, r.reads[r.index]) 21 | r.index += 1 22 | return n, nil 23 | } 24 | // TestControlMessageParser checks that a CLIENT_SETUP message is parsed to the same result whether it arrives in one read or split across two reads. 25 | func TestControlMessageParser(t *testing.T) { 26 | cases := []struct { 27 | mr *mockReader 28 | expect ControlMessage 29 | err error 30 | }{ 31 | { 32 | mr: &mockReader{ 33 | reads: [][]byte{ 34 | {0x40, byte(messageTypeClientSetup), 0x00, 0x04, 0x02, 0x00, 0x01, 0x00}, 35 | }, 36 | index: 0, 37 | }, 38 | expect: &ClientSetupMessage{ 39 | SupportedVersions: []Version{0x00, 0x01}, 40 | SetupParameters: KVPList{}, 41 | }, 42 | err: nil, 43 | }, 44 | { 45 | mr: &mockReader{ 46 | reads: [][]byte{ 47 | {0x40, byte(messageTypeClientSetup), 0x00, 0x04}, 48 | {0x02, 0x00, 0x01, 0x00}, 49 | }, 50 | index: 0, 51 | }, 52 | expect: &ClientSetupMessage{ 53 | SupportedVersions: []Version{0x00, 0x01}, 54 | SetupParameters: KVPList{}, 55 | }, 56 | err: nil, 57 | }, 58 | } 59 | for i, tc := range cases { 60 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 61 | p := 
NewControlMessageParser(tc.mr) 62 | m, err := p.Parse() 63 | assert.Equal(t, tc.expect, m) 64 | if tc.err != nil { 65 | assert.Equal(t, tc.err, err) 66 | } else { 67 | assert.NoError(t, err) 68 | } 69 | }) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /internal/wire/subscribe_done_message.go:
-------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | // SubscribeDoneMessage is the SUBSCRIBE_DONE control message. 9 | type SubscribeDoneMessage struct { 10 | RequestID uint64 11 | StatusCode uint64 12 | StreamCount uint64 13 | ReasonPhrase string 14 | } 15 | // LogValue implements slog.LogValuer. 16 | func (m *SubscribeDoneMessage) LogValue() slog.Value { 17 | return slog.GroupValue( 18 | slog.String("type", "subscribe_done"), 19 | slog.Uint64("request_id", m.RequestID), 20 | slog.Uint64("status_code", m.StatusCode), 21 | slog.Uint64("stream_count", m.StreamCount), 22 | slog.String("reason", m.ReasonPhrase), 23 | ) 24 | } 25 | // Type reports the SUBSCRIBE_DONE control message type. 26 | func (m SubscribeDoneMessage) Type() controlMessageType { 27 | return messageTypeSubscribeDone 28 | } 29 | // Append serializes RequestID, StatusCode and StreamCount as varints, followed by the reason phrase (via appendVarIntBytes), onto buf. 30 | func (m *SubscribeDoneMessage) Append(buf []byte) []byte { 31 | buf = quicvarint.Append(buf, m.RequestID) 32 | buf = quicvarint.Append(buf, m.StatusCode) 33 | buf = quicvarint.Append(buf, m.StreamCount) 34 | buf = appendVarIntBytes(buf, []byte(m.ReasonPhrase)) 35 | return buf 36 | } 37 | // parse decodes the fields in the order Append writes them. The version argument is unused and any bytes after the reason phrase are ignored. 
38 | func (m *SubscribeDoneMessage) parse(_ Version, data []byte) (err error) { 39 | var n int 40 | m.RequestID, n, err = quicvarint.Parse(data) 41 | if err != nil { 42 | return 43 | } 44 | data = data[n:] 45 | 46 | m.StatusCode, n, err = quicvarint.Parse(data) 47 | if err != nil { 48 | return 49 | } 50 | data = data[n:] 51 | 52 | m.StreamCount, n, err = quicvarint.Parse(data) 53 | if err != nil { 54 | return 55 | } 56 | data = data[n:] 57 | 58 | reasonPhrase, _, err := parseVarIntBytes(data) 59 | if err != nil { 60 | return 61 | } 62 | m.ReasonPhrase = string(reasonPhrase) 63 | return nil 64 | } 65 | -------------------------------------------------------------------------------- /internal/wire/track_status_request_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/mengelbart/qlog" 7 | 
"github.com/quic-go/quic-go/quicvarint" 8 | ) 9 | // TrackStatusRequestMessage holds a request ID, the track namespace and name, and request parameters. 10 | type TrackStatusRequestMessage struct { 11 | RequestID uint64 12 | TrackNamespace Tuple 13 | TrackName []byte 14 | Parameters KVPList 15 | } 16 | // LogValue implements slog.LogValuer. It includes request_id, like the other control message loggers. 17 | func (m *TrackStatusRequestMessage) LogValue() slog.Value { 18 | return slog.GroupValue( 19 | slog.String("type", "track_status_request"), 20 | slog.Uint64("request_id", m.RequestID), 21 | slog.Any("track_namespace", m.TrackNamespace), 22 | slog.Any("track_name", qlog.RawInfo{ 23 | Length: uint64(len(m.TrackName)), 24 | PayloadLength: uint64(len(m.TrackName)), 25 | Data: []byte(m.TrackName), 26 | }), 27 | ) 28 | } 29 | // Type reports messageTypeTrackStatus. NOTE(review): confirm this is the intended code for a track status request; if control_message_type.go defines a dedicated request type, this should return it instead. 30 | func (m TrackStatusRequestMessage) Type() controlMessageType { 31 | return messageTypeTrackStatus 32 | } 33 | // Append serializes RequestID (varint), TrackNamespace, TrackName (via appendVarIntBytes) and the count-prefixed parameters onto buf. 34 | func (m *TrackStatusRequestMessage) Append(buf []byte) []byte { 35 | buf = quicvarint.Append(buf, m.RequestID) 36 | buf = m.TrackNamespace.append(buf) 37 | buf = appendVarIntBytes(buf, []byte(m.TrackName)) 38 | return m.Parameters.appendNum(buf) 39 | } 40 | // parse decodes the fields in the order Append writes them. The version argument is unused. 
41 | func (m *TrackStatusRequestMessage) parse(_ Version, data []byte) (err error) { 42 | var n int 43 | m.RequestID, n, err = quicvarint.Parse(data) 44 | if err != nil { 45 | return 46 | } 47 | data = data[n:] 48 | 49 | m.TrackNamespace, n, err = parseTuple(data) 50 | if err != nil { 51 | return 52 | } 53 | data = data[n:] 54 | 55 | m.TrackName, n, err = parseVarIntBytes(data) 56 | if err != nil { 57 | return err 58 | } 59 | data = data[n:] 60 | 61 | m.Parameters = KVPList{} 62 | return m.Parameters.parseNum(data) 63 | } 64 | -------------------------------------------------------------------------------- /local_track_map.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import ( 4 | "sync" 5 | ) 6 | // localTrackMap indexes local tracks by request ID in two states, pending and open; lock guards both maps. 7 | type localTrackMap struct { 8 | lock sync.Mutex 9 | pending map[uint64]*localTrack 10 | open map[uint64]*localTrack 11 | } 12 | // newLocalTrackMap returns an empty localTrackMap. 13 | func newLocalTrackMap() *localTrackMap { 14 | return &localTrackMap{ 15 | lock: sync.Mutex{}, 16 | pending:
map[uint64]*localTrack{}, 17 | open: map[uint64]*localTrack{}, 18 | } 19 | } 20 | // addPending registers lt as pending under lt.requestID. It returns false, without modifying the map, if that ID is already pending or open. 21 | func (m *localTrackMap) addPending(lt *localTrack) bool { 22 | m.lock.Lock() 23 | defer m.lock.Unlock() 24 | if _, ok := m.pending[lt.requestID]; ok { 25 | return false 26 | } 27 | if _, ok := m.open[lt.requestID]; ok { 28 | return false 29 | } 30 | m.pending[lt.requestID] = lt 31 | return true 32 | } 33 | // findByID returns the track with the given ID, checking pending before open. 34 | func (m *localTrackMap) findByID(id uint64) (*localTrack, bool) { 35 | m.lock.Lock() 36 | defer m.lock.Unlock() 37 | sub, ok := m.pending[id] 38 | if !ok { 39 | sub, ok = m.open[id] 40 | } 41 | return sub, ok 42 | } 43 | // delete removes the track with the given ID from whichever state holds it and returns it; the boolean is false if the ID is unknown. 44 | func (m *localTrackMap) delete(id uint64) (*localTrack, bool) { 45 | m.lock.Lock() 46 | defer m.lock.Unlock() 47 | sub, ok := m.pending[id] 48 | if !ok { 49 | sub, ok = m.open[id] 50 | } 51 | if !ok { 52 | return nil, false 53 | } 54 | delete(m.pending, id) 55 | delete(m.open, id) 56 | return sub, true 57 | } 58 | // confirm moves the pending track with the given ID to open and returns it; the boolean is false if the ID is not pending. 59 | func (m *localTrackMap) confirm(id uint64) (*localTrack, bool) { 60 | m.lock.Lock() 61 | defer m.lock.Unlock() 62 | lt, ok := m.pending[id] 63 | if !ok { 64 | return nil, false 65 | } 66 | delete(m.pending, id) 67 | m.open[id] = lt 68 | return lt, true 69 | } 70 | // reject removes the pending track with the given ID and returns it; the boolean is false if the ID is not pending. 
71 | func (m *localTrackMap) reject(id uint64) (*localTrack, bool) { 72 | m.lock.Lock() 73 | defer m.lock.Unlock() 74 | lt, ok := m.pending[id] 75 | if !ok { 76 | return nil, false 77 | } 78 | delete(m.pending, id) 79 | return lt, true 80 | } 81 | -------------------------------------------------------------------------------- /internal/wire/announce_ok_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | // TestAnnounceOkMessageAppend checks that Append writes RequestID as a varint after any bytes already in buf. 11 | func TestAnnounceOkMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | aom AnnounceOkMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | aom: AnnounceOkMessage{ 19 | RequestID: 1, 20 |
}, 21 | buf: []byte{}, 22 | expect: []byte{ 23 | 0x01, 24 | }, 25 | }, 26 | { 27 | aom: AnnounceOkMessage{ 28 | RequestID: 1, 29 | }, 30 | buf: []byte{0x0a, 0x0b}, 31 | expect: []byte{0x0a, 0x0b, 0x01}, 32 | }, 33 | } 34 | for i, tc := range cases { 35 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 36 | res := tc.aom.Append(tc.buf) 37 | assert.Equal(t, tc.expect, res) 38 | }) 39 | } 40 | } 41 | // TestParseAnnounceOkMessage checks parsing of a single-byte request ID and that empty input yields io.EOF. NOTE(review): the second and third cases are identical. 42 | func TestParseAnnounceOkMessage(t *testing.T) { 43 | cases := []struct { 44 | data []byte 45 | expect *AnnounceOkMessage 46 | err error 47 | }{ 48 | { 49 | data: nil, 50 | expect: &AnnounceOkMessage{}, 51 | err: io.EOF, 52 | }, 53 | { 54 | data: []byte{0x01}, 55 | expect: &AnnounceOkMessage{ 56 | RequestID: 1, 57 | }, 58 | err: nil, 59 | }, 60 | { 61 | data: []byte{0x01}, 62 | expect: &AnnounceOkMessage{ 63 | RequestID: 1, 64 | }, 65 | err: nil, 66 | }, 67 | { 68 | data: []byte{}, 69 | expect: &AnnounceOkMessage{ 70 | RequestID: 0, 71 | }, 72 | err: io.EOF, 73 | }, 74 | } 75 | for i, tc := range cases { 76 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 77 | res := &AnnounceOkMessage{} 78 | err := res.parse(CurrentVersion, tc.data) 79 | assert.Equal(t, tc.expect, res) 80 | if tc.err != nil { 81 | assert.Equal(t, tc.err, err) 82 | } else { 83 | assert.NoError(t, err) 84 | } 85 | }) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /announcement_map.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import ( 4 | "slices" 5 | "sync" 6 | ) 7 | // findAnnouncement returns an entry of as whose namespace equals namespace element-wise, or nil if there is none. Map iteration order is unspecified, so if several entries match, which one is returned is arbitrary. 
8 | func findAnnouncement(as map[uint64]*announcement, namespace []string) *announcement { 9 | for _, v := range as { 10 | if slices.Equal(namespace, v.namespace) { 11 | return v 12 | } 13 | } 14 | return nil 15 | } 16 | // announcementMap stores announcements by request ID, split into pending entries and confirmed ones (moved there by confirmAndGet); lock guards both maps. 17 | type announcementMap struct { 18 | lock sync.Mutex 19 | pending map[uint64]*announcement 20 | announcements map[uint64]*announcement 21 | } 22 | // newAnnouncementMap returns an empty announcementMap. 23 | func newAnnouncementMap()
*announcementMap { 24 | return &announcementMap{ 25 | lock: sync.Mutex{}, 26 | pending: make(map[uint64]*announcement), 27 | announcements: make(map[uint64]*announcement), 28 | } 29 | } 30 | // add stores a as pending under a.requestID, replacing any pending entry with the same ID. 31 | func (m *announcementMap) add(a *announcement) { 32 | m.lock.Lock() 33 | defer m.lock.Unlock() 34 | m.pending[a.requestID] = a 35 | } 36 | // confirmAndGet moves the pending announcement with requestID to the confirmed set and returns it, or returns errUnknownAnnouncement if no such pending entry exists. 37 | func (m *announcementMap) confirmAndGet(requestID uint64) (*announcement, error) { 38 | m.lock.Lock() 39 | defer m.lock.Unlock() 40 | a, ok := m.pending[requestID] 41 | if !ok { 42 | return nil, errUnknownAnnouncement 43 | } 44 | delete(m.pending, requestID) 45 | m.announcements[requestID] = a 46 | return a, nil 47 | } 48 | // reject removes the pending announcement with requestID and returns it; the boolean is false if it was not pending. 49 | func (m *announcementMap) reject(requestID uint64) (*announcement, bool) { 50 | m.lock.Lock() 51 | defer m.lock.Unlock() 52 | a, ok := m.pending[requestID] 53 | if !ok { 54 | return nil, false 55 | } 56 | delete(m.pending, requestID) 57 | return a, true 58 | } 59 | // delete removes one announcement with the given namespace, checking pending entries before confirmed ones, and reports whether anything was removed. 
60 | func (m *announcementMap) delete(namespace []string) bool { 61 | m.lock.Lock() 62 | defer m.lock.Unlock() 63 | a := findAnnouncement(m.pending, namespace) 64 | if a != nil { 65 | delete(m.pending, a.requestID) 66 | return true 67 | } 68 | a = findAnnouncement(m.announcements, namespace) 69 | if a != nil { 70 | delete(m.announcements, a.requestID) 71 | return true 72 | } 73 | return false 74 | } 75 | -------------------------------------------------------------------------------- /session_helpers.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "math" 7 | "slices" 8 | 9 | "github.com/mengelbart/moqtransport/internal/wire" 10 | "github.com/quic-go/quic-go/quicvarint" 11 | ) 12 | // errControlMessageTooLarge is returned by compileMessage when the encoded message body does not fit the 16-bit length field. 13 | var errControlMessageTooLarge = errors.New("control message too large") 14 | // compileMessage serializes msg as its varint message type, a 16-bit big-endian body length, and the body produced by msg.Append. 15 | func compileMessage(msg wire.ControlMessage) ([]byte, error) { 16 | buf := make([]byte, 0, 4096) 17 | buf = quicvarint.Append(buf, uint64(msg.Type())) 18 | tl :=
len(buf) 19 | buf = append(buf, 0x00, 0x00) // length placeholder 20 | buf = msg.Append(buf) 21 | length := len(buf[tl+2:]) 22 | if length > math.MaxUint16 { 23 | return nil, errControlMessageTooLarge 24 | } 25 | binary.BigEndian.PutUint16(buf[tl:tl+2], uint16(length)) 26 | return buf, nil 27 | } 28 | // validatePathParameter returns the value of the PATH setup parameter. The parameter is required when protocolIsQUIC is true and must be absent otherwise. 29 | func validatePathParameter(setupParameters wire.KVPList, protocolIsQUIC bool) (string, error) { 30 | index := slices.IndexFunc(setupParameters, func(p wire.KeyValuePair) bool { 31 | return p.Type == wire.PathParameterKey 32 | }) 33 | if index < 0 { 34 | if protocolIsQUIC { 35 | return "", errMissingPathParameter 36 | } 37 | return "", nil 38 | } 39 | if !protocolIsQUIC { // present at any index, including 0, but only allowed when protocolIsQUIC 40 | return "", errUnexpectedPathParameter 41 | } 42 | return string(setupParameters[index].ValueBytes), nil 43 | } 44 | // getMaxRequestIDParameter returns the varint value of the wire.MaxRequestIDParameterKey setup parameter, or 0 if it is absent. 45 | func getMaxRequestIDParameter(setupParameters wire.KVPList) uint64 { 46 | index := slices.IndexFunc(setupParameters, func(p wire.KeyValuePair) bool { 47 | return p.Type == wire.MaxRequestIDParameterKey 48 | }) 49 | if index < 0 { 50 | return 0 51 | } 52 | return setupParameters[index].ValueVarInt 53 | } 54 | // validateAuthParameter returns the value of the authorization token parameter, or "" if it is absent. The error result is currently always nil. 
55 | func validateAuthParameter(subscribeParameters wire.KVPList) (string, error) { 56 | index := slices.IndexFunc(subscribeParameters, func(p wire.KeyValuePair) bool { 57 | return p.Type == wire.AuthorizationTokenParameterKey 58 | }) 59 | if index < 0 { 60 | return "", nil 61 | } 62 | return string(subscribeParameters[index].ValueBytes), nil 63 | } 64 | -------------------------------------------------------------------------------- /internal/wire/fetch_ok_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | // TODO: Add tests 10 | type FetchOkMessage struct { 11 | RequestID uint64 12 | GroupOrder uint8 13 | EndOfTrack uint8 14 | EndLocation Location 15 | SubscribeParameters KVPList 16 | } 17 | // LogValue implements slog.LogValuer. NOTE(review): EndLocation is logged under the largest_group_id and largest_object_id keys. 18 | func (m
*FetchOkMessage) LogValue() slog.Value { 19 | attrs := []slog.Attr{ 20 | slog.String("type", "fetch_ok"), 21 | slog.Uint64("request_id", m.RequestID), 22 | slog.Any("group_order", m.GroupOrder), 23 | slog.Any("end_of_track", m.EndOfTrack), 24 | slog.Uint64("largest_group_id", m.EndLocation.Group), 25 | slog.Uint64("largest_object_id", m.EndLocation.Object), 26 | slog.Uint64("number_of_parameters", uint64(len(m.SubscribeParameters))), 27 | } 28 | if len(m.SubscribeParameters) > 0 { 29 | attrs = append(attrs, 30 | slog.Any("subscribe_parameters", m.SubscribeParameters), 31 | ) 32 | } 33 | return slog.GroupValue(attrs...) 34 | } 35 | // Type reports the FETCH_OK control message type. 36 | func (m FetchOkMessage) Type() controlMessageType { 37 | return messageTypeFetchOk 38 | } 39 | // Append serializes RequestID (varint), GroupOrder and EndOfTrack (one byte each), EndLocation and the parameters onto buf. 40 | func (m *FetchOkMessage) Append(buf []byte) []byte { 41 | buf = quicvarint.Append(buf, m.RequestID) 42 | buf = append(buf, m.GroupOrder) 43 | buf = append(buf, m.EndOfTrack) 44 | buf = m.EndLocation.append(buf) 45 | return m.SubscribeParameters.appendNum(buf) 46 | } 47 | // parse decodes the fields in the order Append writes them; a GroupOrder above 2 is rejected with errInvalidGroupOrder. 48 | func (m *FetchOkMessage) parse(v Version, data []byte) (err error) { 49 | var n int 50 | m.RequestID, n, err = quicvarint.Parse(data) 51 | if err != nil { 52 | return err 53 | } 54 | data = data[n:] 55 | 56 | if len(data) < 2 { 57 | return errLengthMismatch 58 | } 59 | m.GroupOrder = data[0] 60 | if m.GroupOrder > 2 { 61 | return errInvalidGroupOrder 62 | } 63 | m.EndOfTrack = data[1] 64 | data = data[2:] 65 | 66 | n, err = m.EndLocation.parse(v, data) 67 | if err != nil { 68 | return err 69 | } 70 | data = data[n:] 71 | 72 | // Reset before parsing, as ClientSetupMessage and TrackStatusRequestMessage do, so the result never depends on a previous value of the field. 
73 | m.SubscribeParameters = KVPList{} 74 | return m.SubscribeParameters.parseNum(data) 75 | } 76 | -------------------------------------------------------------------------------- /internal/wire/kvp_list_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestParseKVPList(t *testing.T) { 12 | cases := []struct { 13 |
data []byte 14 | expect KVPList 15 | err error 16 | }{ 17 | { 18 | data: nil, 19 | expect: KVPList{}, 20 | err: io.EOF, 21 | }, 22 | { // NOTE(review): identical to the previous case 23 | data: nil, 24 | expect: KVPList{}, 25 | err: io.EOF, 26 | }, 27 | { 28 | data: []byte{}, 29 | expect: KVPList{}, 30 | err: io.EOF, 31 | }, 32 | { 33 | data: []byte{0x01, 0x01, 0x01, 'A'}, 34 | expect: KVPList{KeyValuePair{ 35 | Type: 1, 36 | ValueBytes: []byte("A"), 37 | }}, 38 | err: nil, 39 | }, 40 | { 41 | data: []byte{0x02, 0x02, 0x03, 0x01, 0x01, 'A'}, 42 | expect: KVPList{ 43 | KeyValuePair{ 44 | Type: 2, 45 | ValueVarInt: uint64(3), 46 | }, 47 | KeyValuePair{ 48 | Type: 1, 49 | ValueBytes: []byte("A"), 50 | }, 51 | }, 52 | err: nil, 53 | }, 54 | { 55 | data: []byte{0x01, 0x01, 0x01, 'A', 0x02, 0x02, 0x02, 0x02}, // bytes after the single declared pair are not parsed 56 | expect: KVPList{KeyValuePair{ 57 | Type: 1, 58 | ValueBytes: []byte("A"), 59 | }}, 60 | err: nil, 61 | }, 62 | { // NOTE(review): identical to the third case 63 | data: []byte{}, 64 | expect: KVPList{}, 65 | err: io.EOF, 66 | }, 67 | { 68 | data: []byte{0x02, 0x0f, 0x01, 0x00, 0x01, 0x01, 'A'}, 69 | expect: KVPList{ 70 | KeyValuePair{ 71 | Type: 0x0f, 72 | ValueBytes: []byte{0x00}, 73 | }, 74 | KeyValuePair{ 75 | Type: PathParameterKey, 76 | ValueBytes: []byte("A"), 77 | }, 78 | }, 79 | err: nil, 80 | }, 81 | } 82 | for i, tc := range cases { 83 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 84 | res := KVPList{} 85 | err := res.parseNum(tc.data) 86 | assert.Equal(t, tc.expect, res) 87 | if tc.err != nil { 88 | assert.Equal(t, tc.err, err) 89 | } else { 90 | assert.NoError(t, err) 91 | } 92 | }) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /internal/wire/tuple_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | // 
TestAppendTuple checks the count-prefixed encoding of tuple elements, each written with its own length prefix. 11 | func TestAppendTuple(t *testing.T) { 12 | cases := []struct { 13 | t Tuple 14 | buf []byte 15 | expect []byte 16 | }{ 17 | 
{ 18 | t: nil, 19 | buf: []byte{}, 20 | expect: []byte{0x00}, 21 | }, 22 | { 23 | t: []string{"A"}, 24 | buf: []byte{}, 25 | expect: []byte{0x01, 0x01, 'A'}, 26 | }, 27 | { 28 | t: []string{"A", "ABC"}, 29 | buf: []byte{}, 30 | expect: []byte{0x02, 0x01, 'A', 0x03, 'A', 'B', 'C'}, 31 | }, 32 | } 33 | for i, tc := range cases { 34 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 35 | res := tc.t.append(tc.buf) 36 | assert.Equal(t, tc.expect, res) 37 | }) 38 | } 39 | } 40 | // TestParseTuple checks the decoded elements and consumed byte count, including that a truncated tuple yields the elements decoded so far together with io.EOF. 41 | func TestParseTuple(t *testing.T) { 42 | cases := []struct { 43 | data []byte 44 | expect Tuple 45 | err error 46 | n int 47 | }{ 48 | { 49 | data: []byte{}, 50 | expect: nil, 51 | err: io.EOF, 52 | n: 0, 53 | }, 54 | { 55 | data: []byte{0x02, 0x01, 'a'}, 56 | expect: []string{"a"}, 57 | err: io.EOF, 58 | n: 3, 59 | }, 60 | { 61 | data: []byte{0x00}, 62 | expect: []string{}, 63 | err: nil, 64 | n: 1, 65 | }, 66 | { 67 | data: []byte{0x01, 0x01, 'a'}, 68 | expect: []string{"a"}, 69 | err: nil, 70 | n: 3, 71 | }, 72 | { 73 | data: []byte{0x02, 0x01, 'a', 0x02, 'a', 'b'}, 74 | expect: []string{"a", "ab"}, 75 | err: nil, 76 | n: 6, 77 | }, 78 | } 79 | for i, tc := range cases { 80 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 81 | res, n, err := parseTuple(tc.data) 82 | if tc.err != nil { 83 | assert.Error(t, err) 84 | assert.Equal(t, tc.err, err) 85 | assert.Equal(t, tc.expect, res) 86 | assert.Equal(t, tc.n, n) 87 | return 88 | } 89 | assert.NoError(t, err) 90 | assert.Equal(t, tc.expect, res) 91 | assert.Equal(t, tc.n, n) 92 | }) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /internal/wire/announce_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | // TestAnnounceMessageAppend checks the 
encoding of RequestID, TrackNamespace and an empty parameter list, including after bytes already in buf. 11 | func TestAnnounceMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | am AnnounceMessage 14 | buf []byte 15
| expect []byte 16 | }{ 17 | { 18 | am: AnnounceMessage{ 19 | RequestID: 0, 20 | TrackNamespace: []string{""}, 21 | Parameters: KVPList{}, 22 | }, 23 | buf: []byte{}, 24 | expect: []byte{ 25 | 0x00, 0x01, 0x00, 0x00, 26 | }, 27 | }, 28 | { 29 | am: AnnounceMessage{ 30 | RequestID: 1, 31 | TrackNamespace: []string{"tracknamespace"}, 32 | Parameters: KVPList{}, 33 | }, 34 | buf: []byte{0x0a, 0x0b}, 35 | expect: []byte{0x0a, 0x0b, 0x01, 0x01, 0x0e, 't', 'r', 'a', 'c', 'k', 'n', 'a', 'm', 'e', 's', 'p', 'a', 'c', 'e', 0x00}, 36 | }, 37 | } 38 | for i, tc := range cases { 39 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 40 | res := tc.am.Append(tc.buf) 41 | assert.Equal(t, tc.expect, res) 42 | }) 43 | } 44 | } 45 | // TestParseAnnounceMessage checks that empty input yields io.EOF and that a request ID, single-element namespace and empty parameter list are decoded. 46 | func TestParseAnnounceMessage(t *testing.T) { 47 | cases := []struct { 48 | data []byte 49 | expect *AnnounceMessage 50 | err error 51 | }{ 52 | { 53 | data: nil, 54 | expect: &AnnounceMessage{}, 55 | err: io.EOF, 56 | }, 57 | { 58 | data: []byte{}, 59 | expect: &AnnounceMessage{}, 60 | err: io.EOF, 61 | }, 62 | { 63 | data: append(append([]byte{0x00, 0x01, 0x09}, "trackname"...), 0x00), 64 | expect: &AnnounceMessage{ 65 | RequestID: 0, 66 | TrackNamespace: []string{"trackname"}, 67 | Parameters: KVPList{}, 68 | }, 69 | err: nil, 70 | }, 71 | } 72 | for i, tc := range cases { 73 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 74 | res := &AnnounceMessage{} 75 | err := res.parse(CurrentVersion, tc.data) 76 | assert.Equal(t, tc.expect, res) 77 | if tc.err != nil { 78 | assert.Equal(t, tc.err, err) 79 | } else { 80 | assert.NoError(t, err) 81 | } 82 | }) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /internal/wire/subscribe_error_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | 
"testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | // TestSubscribeErrorMessageAppend checks that Append writes RequestID, ErrorCode and the length-prefixed reason phrase, including after bytes already in buf. 11 | func TestSubscribeErrorMessageAppend(t
*testing.T) { 12 | cases := []struct { 13 | sem SubscribeErrorMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | sem: SubscribeErrorMessage{ 19 | RequestID: 0, 20 | ErrorCode: 0, 21 | ReasonPhrase: "", 22 | }, 23 | buf: []byte{0x0a, 0x0b}, 24 | expect: []byte{ 25 | 0x0a, 0x0b, 0x00, 0x00, 0x00, 26 | }, 27 | }, 28 | { 29 | sem: SubscribeErrorMessage{ 30 | RequestID: 17, 31 | ErrorCode: 12, 32 | ReasonPhrase: "reason", 33 | }, 34 | buf: []byte{}, 35 | expect: []byte{0x11, 0x0c, 0x06, 'r', 'e', 'a', 's', 'o', 'n'}, 36 | }, 37 | } 38 | for i, tc := range cases { 39 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 40 | res := tc.sem.Append(tc.buf) 41 | assert.Equal(t, tc.expect, res) 42 | }) 43 | } 44 | } 45 | // TestParseSubscribeErrorMessage checks complete and truncated inputs; a truncated reason phrase yields io.ErrUnexpectedEOF with the leading fields already set. 46 | func TestParseSubscribeErrorMessage(t *testing.T) { 47 | cases := []struct { 48 | data []byte 49 | expect *SubscribeErrorMessage 50 | err error 51 | }{ 52 | { 53 | data: nil, 54 | expect: &SubscribeErrorMessage{}, 55 | err: io.EOF, 56 | }, 57 | { 58 | data: []byte{0x01, 0x02, 0x03, 0x04}, 59 | expect: &SubscribeErrorMessage{ 60 | RequestID: 1, 61 | ErrorCode: 2, 62 | ReasonPhrase: "", 63 | }, 64 | err: io.ErrUnexpectedEOF, 65 | }, 66 | { 67 | data: []byte{0x00, 0x01, 0x05, 'e', 'r', 'r', 'o', 'r'}, 68 | expect: &SubscribeErrorMessage{ 69 | RequestID: 0, 70 | ErrorCode: 1, 71 | ReasonPhrase: "error", 72 | }, 73 | err: nil, 74 | }, 75 | } 76 | for i, tc := range cases { 77 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 78 | res := &SubscribeErrorMessage{} 79 | err := res.parse(CurrentVersion, tc.data) 80 | assert.Equal(t, tc.expect, res) 81 | if tc.err != nil { 82 | assert.Equal(t, tc.err, err) 83 | } else { 84 | assert.NoError(t, err) 85 | } 86 | }) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /mockgen.go: 
-------------------------------------------------------------------------------- 1 | //go:build gomock || generate 2 | 3 | package moqtransport 4 | // Each go:generate directive below writes a typed gomock mock of one interface to a mock_*_test.go file. 5 | //go:generate sh -c "go run 
go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -typed -package moqtransport -write_package_comment=false -self_package github.com/mengelbart/moqtransport -destination mock_handler_test.go github.com/mengelbart/moqtransport Handler" 6 | 7 | //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -typed -package moqtransport -write_package_comment=false -self_package github.com/mengelbart/moqtransport -destination mock_stream_test.go github.com/mengelbart/moqtransport Stream" 8 | 9 | //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -typed -package moqtransport -write_package_comment=false -self_package github.com/mengelbart/moqtransport -destination mock_receive_stream_test.go github.com/mengelbart/moqtransport ReceiveStream" 10 | 11 | //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -typed -package moqtransport -write_package_comment=false -self_package github.com/mengelbart/moqtransport -destination mock_send_stream_test.go github.com/mengelbart/moqtransport SendStream" 12 | 13 | //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -typed -package moqtransport -write_package_comment=false -self_package github.com/mengelbart/moqtransport -destination mock_connection_test.go github.com/mengelbart/moqtransport Connection" 14 | 15 | //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -typed -package moqtransport -write_package_comment=false -self_package github.com/mengelbart/moqtransport -destination mock_object_message_parser_test.go github.com/mengelbart/moqtransport ObjectMessageParser" 16 | type ObjectMessageParser = objectMessageParser // exported alias of the unexported interface, targeted by the mockgen directive above 17 | 18 | //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -typed -package moqtransport -write_package_comment=false -self_package github.com/mengelbart/moqtransport 
-destination mock_control_message_stream_test.go 
github.com/mengelbart/moqtransport ControlMessageStream" 19 | type ControlMessageStream = controlMessageStream // exported alias of the unexported interface, targeted by the mockgen directive above 20 | -------------------------------------------------------------------------------- /internal/wire/track_status_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | // TestTrackStatusMessageAppend checks that Append writes RequestID, StatusCode, LargestLocation and the parameter count; nil Parameters encode as a count of zero. 11 | func TestTrackStatusMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | tsm TrackStatusMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | tsm: TrackStatusMessage{ 19 | RequestID: 0, 20 | StatusCode: 0, 21 | LargestLocation: Location{ 22 | Group: 0, 23 | Object: 0, 24 | }, 25 | }, 26 | buf: []byte{}, 27 | expect: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, 28 | }, 29 | { 30 | tsm: TrackStatusMessage{ 31 | RequestID: 1, 32 | StatusCode: 2, 33 | LargestLocation: Location{ 34 | Group: 1, 35 | Object: 2, 36 | }, 37 | Parameters: KVPList{}, 38 | }, 39 | buf: []byte{0x0a, 0x0b}, 40 | expect: []byte{0x0a, 0x0b, 0x01, 0x02, 0x01, 0x02, 0x00}, 41 | }, 42 | } 43 | for i, tc := range cases { 44 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 45 | res := tc.tsm.Append(tc.buf) 46 | assert.Equal(t, tc.expect, res) 47 | }) 48 | } 49 | } 50 | // TestParseTrackStatusMessage checks that empty input yields io.EOF and that all fields, including an empty parameter list, are decoded. 51 | func TestParseTrackStatusMessage(t *testing.T) { 52 | cases := []struct { 53 | data []byte 54 | expect *TrackStatusMessage 55 | err error 56 | }{ 57 | { 58 | data: nil, 59 | expect: &TrackStatusMessage{}, 60 | err: io.EOF, 61 | }, 62 | { 63 | data: []byte{}, 64 | expect: &TrackStatusMessage{}, 65 | err: io.EOF, 66 | }, 67 | { 68 | data: []byte{0x01, 0x02, 0x03, 0x04, 0x00}, 69 | expect: &TrackStatusMessage{ 70 | RequestID: 1, 71 | StatusCode: 2, 72 | 
LargestLocation: Location{ 73 | Group: 3, 74 | Object: 4, 75 | }, 76 | Parameters: KVPList{}, 77 | }, 78 | err: nil, 79 | }, 80 | } 81 | for i, tc := range cases { 82 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 83 | res :=
&TrackStatusMessage{} 84 | err := res.parse(CurrentVersion, tc.data) 85 | assert.Equal(t, tc.expect, res) 86 | if tc.err != nil { 87 | assert.Equal(t, tc.err, err) 88 | } else { 89 | assert.NoError(t, err) 90 | } 91 | }) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /internal/wire/token.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import "github.com/quic-go/quic-go/quicvarint" 4 | // Values for Token.AliasType. 5 | const ( 6 | TokenTypeDelete = 0x00 7 | TokenTypeRegister = 0x01 8 | TokenTypeUseAlias = 0x02 9 | TokenTypeUseValue = 0x03 10 | ) 11 | // Token is a token whose AliasType selects which of Alias, Type and Value are encoded. 12 | type Token struct { 13 | AliasType uint64 14 | Alias uint64 15 | Type uint64 16 | Value []byte 17 | } 18 | // Append serializes t onto buf; the fields written depend on AliasType. NOTE(review): an unrecognized AliasType writes only the AliasType varint. 19 | func (t Token) Append(buf []byte) []byte { 20 | buf = quicvarint.Append(buf, t.AliasType) 21 | switch t.AliasType { 22 | case TokenTypeDelete, TokenTypeUseAlias: 23 | buf = quicvarint.Append(buf, t.Alias) 24 | case TokenTypeRegister: 25 | buf = quicvarint.Append(buf, t.Alias) 26 | buf = quicvarint.Append(buf, t.Type) 27 | buf = append(buf, t.Value...) 28 | case TokenTypeUseValue: 29 | buf = quicvarint.Append(buf, t.Type) 30 | buf = append(buf, t.Value...) 
31 | } 32 | return buf 33 | } 34 | // Parse decodes a token from data and returns the number of bytes consumed. For TokenTypeRegister and TokenTypeUseValue, Value takes all remaining bytes of data. NOTE(review): an unrecognized AliasType is accepted with a nil error after consuming only the AliasType varint. 35 | func (t *Token) Parse(data []byte) (parsed int, err error) { 36 | var n int 37 | t.AliasType, n, err = quicvarint.Parse(data) 38 | parsed += n 39 | if err != nil { 40 | return parsed, err 41 | } 42 | data = data[n:] 43 | 44 | switch t.AliasType { 45 | case TokenTypeDelete, TokenTypeUseAlias: 46 | t.Alias, n, err = quicvarint.Parse(data) 47 | parsed += n 48 | if err != nil { 49 | return parsed, err 50 | } 51 | 52 | case TokenTypeRegister: 53 | t.Alias, n, err = quicvarint.Parse(data) 54 | parsed += n 55 | if err != nil { 56 | return parsed, err 57 | } 58 | data = data[n:] 59 | 60 | t.Type, n, err = quicvarint.Parse(data) 61 | parsed += n 62 | if err != nil { 63 | return parsed, err 64 | } 65 | data = data[n:] 66 | 67 | t.Value = make([]byte, len(data)) 68 | n = copy(t.Value, data) 69 | parsed += n 70 | if n != len(data) { // NOTE(review): never true; t.Value has len(data) bytes, so copy always copies all of data 71 | return parsed, errLengthMismatch 72 | } 73 | 74 | case TokenTypeUseValue: 75 | t.Type, n, err = quicvarint.Parse(data) 76 | parsed += n 77 | if err != nil { 78 | return parsed, err 79 | } 80 | data = data[n:] 81 | 82 | t.Value = make([]byte, len(data)) 83 | n = copy(t.Value, data) 84 | parsed += n 85 | if n != len(data) { // NOTE(review): never true, for the same reason as in the Register case 86 | return parsed, errLengthMismatch 87 | } 88 | } 89 | return 90 | } 91 | -------------------------------------------------------------------------------- /internal/wire/version.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | // Version identifies a MoQ 
Transport draft version. 9 | type Version uint64 10 | // Known draft versions. There is no constant for draft 09. 11 | const ( 12 | Draft_ietf_moq_transport_00 Version = 0xff000000 13 | Draft_ietf_moq_transport_01 Version = 0xff000001 14 | Draft_ietf_moq_transport_02 Version = 0xff000002 15 | Draft_ietf_moq_transport_03 Version = 0xff000003 16 | Draft_ietf_moq_transport_04 Version = 0xff000004 17 | Draft_ietf_moq_transport_05 Version = 0xff000005 18 | Draft_ietf_moq_transport_06 Version = 0xff000006 19 | 
Draft_ietf_moq_transport_07 Version = 0xff000007 20 | Draft_ietf_moq_transport_08 Version = 0xff000008 21 | Draft_ietf_moq_transport_10 Version = 0xff00000a 22 | Draft_ietf_moq_transport_11 Version = 0xff00000b 23 | 24 | CurrentVersion = Draft_ietf_moq_transport_11 25 | ) 26 | // SupportedVersions lists the supported versions; currently only CurrentVersion. 27 | var SupportedVersions = []Version{CurrentVersion} 28 | // String formats v as hexadecimal with a 0x prefix. 29 | func (v Version) String() string { 30 | return fmt.Sprintf("0x%x", uint64(v)) 31 | } 32 | // Len returns the number of bytes v occupies when encoded as a QUIC varint. 33 | func (v Version) Len() uint64 { 34 | return uint64(quicvarint.Len(uint64(v))) 35 | } 36 | // versions is a list of Versions with a count-prefixed varint wire encoding. 37 | type versions []Version 38 | // String formats the list as "[a, b, c]", using Version.String for each element. 39 | func (v versions) String() string { 40 | res := "[" 41 | for i, e := range v { 42 | if i < len(v)-1 { 43 | res += fmt.Sprintf("%v, ", e) 44 | } else { 45 | res += fmt.Sprintf("%v", e) 46 | } 47 | } 48 | res += "]" 49 | return res 50 | } 51 | // Len returns the sum of the elements' varint lengths; the count prefix is not included. 52 | func (v versions) Len() uint64 { 53 | l := uint64(0) 54 | for _, x := range v { 55 | l = l + x.Len() 56 | } 57 | return l 58 | } 59 | // append writes the number of versions followed by each version, all as varints. 60 | func (v versions) append(buf []byte) []byte { 61 | buf = quicvarint.Append(buf, uint64(len(v))) 62 | for _, vv := range v { 63 | buf = quicvarint.Append(buf, uint64(vv)) 64 | } 65 | return buf 66 | } 67 | // parse appends the decoded versions to *vs (it does not reset the list) and returns the number of bytes consumed. 
68 | func (vs *versions) parse(data []byte) (int, error) { 69 | numVersions, parsed, err := quicvarint.Parse(data) 70 | if err != nil { 71 | return parsed, err 72 | } 73 | data = data[parsed:] 74 | 75 | for i := 0; i < int(numVersions); i++ { 76 | v, n, err := quicvarint.Parse(data) 77 | parsed += n 78 | if err != nil { 79 | return parsed, err 80 | } 81 | data = data[n:] 82 | *vs = append(*vs, Version(v)) 83 | } 84 | return parsed, nil 85 | } 86 | -------------------------------------------------------------------------------- /internal/wire/unannounce_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestUnannounceMessageAppend(t *testing.T) { 12 | 
cases := []struct { 13 | uam UnannounceMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | uam: UnannounceMessage{ 19 | TrackNamespace: []string{""}, 20 | }, 21 | buf: []byte{}, 22 | expect: []byte{ 23 | 0x01, 0x00, 24 | }, 25 | }, 26 | { 27 | uam: UnannounceMessage{ 28 | TrackNamespace: []string{"tracknamespace"}, 29 | }, 30 | buf: []byte{0x0a, 0x0b}, 31 | expect: []byte{0x0a, 0x0b, 0x01, 0x0e, 't', 'r', 'a', 'c', 'k', 'n', 'a', 'm', 'e', 's', 'p', 'a', 'c', 'e'}, 32 | }, 33 | } 34 | for i, tc := range cases { 35 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 36 | res := tc.uam.Append(tc.buf) 37 | assert.Equal(t, tc.expect, res) 38 | }) 39 | } 40 | } 41 | 42 | func TestParseUnannounceMessage(t *testing.T) { 43 | cases := []struct { 44 | data []byte 45 | expect *UnannounceMessage 46 | err error 47 | }{ 48 | { 49 | data: nil, 50 | expect: &UnannounceMessage{}, 51 | err: io.EOF, 52 | }, 53 | { 54 | data: append([]byte{0x01, 0x0E}, "tracknamespace"...), 55 | expect: &UnannounceMessage{ 56 | TrackNamespace: []string{"tracknamespace"}, 57 | }, 58 | err: nil, 59 | }, 60 | { 61 | data: append([]byte{0x01, 0x05}, "tracknamespace"...), 62 | expect: &UnannounceMessage{ 63 | TrackNamespace: []string{"track"}, 64 | }, 65 | err: nil, 66 | }, 67 | { 68 | data: append([]byte{0x01, 0x0F}, "tracknamespace"...), 69 | expect: &UnannounceMessage{ 70 | TrackNamespace: []string{}, 71 | }, 72 | err: errLengthMismatch, 73 | }, 74 | } 75 | for i, tc := range cases { 76 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 77 | res := &UnannounceMessage{} 78 | err := res.parse(CurrentVersion, tc.data) 79 | if tc.err != nil { 80 | assert.Equal(t, tc.err, err) 81 | assert.Equal(t, tc.expect, res) 82 | return 83 | } 84 | assert.NoError(t, err) 85 | assert.Equal(t, tc.expect, res) 86 | }) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 
# Media over QUIC Transport (MoQT) 2 | 3 | [![Go Reference](https://pkg.go.dev/badge/github.com/mengelbart/moqtransport.svg)](https://pkg.go.dev/github.com/mengelbart/moqtransport) 4 | 5 | `moqtransport` is a Go implementation of [Media over QUIC Transport](https://datatracker.ietf.org/doc/draft-ietf-moq-transport/) on top of [quic-go](https://github.com/quic-go/quic-go) and optionally [webtransport-go](https://github.com/quic-go/webtransport-go/). 6 | 7 | ## Overview 8 | 9 | This library implements the Media over QUIC Transport (MoQT) protocol as defined in [draft-ietf-moq-transport-11](https://www.ietf.org/archive/id/draft-ietf-moq-transport-11.txt). MoQT is designed to operate over QUIC or WebTransport for efficient media delivery with a publish/subscribe model. 10 | 11 | ### Implementation Status 12 | 13 | This code, as well as the specification, is a work in progress. 14 | The implementation currently covers most aspects of the MoQT specification (draft-11), including: 15 | 16 | - Session establishment and initialization 17 | - Control message encoding and handling 18 | - Data stream management 19 | - Track announcement and subscription 20 | - Error handling 21 | - Support for both QUIC and WebTransport 22 | 23 | ### Areas for Future Development 24 | 25 | - Implementation of FETCH 26 | - Exposure of more parameters 27 | - ... 28 | 29 | ## Usage 30 | 31 | See the [date example](examples/date/README.md) in the `examples` directory for a simple demonstration of how to use this library. 32 | 33 | Basic usage involves: 34 | 35 | 1. Creating a connection using either QUIC or WebTransport 36 | 2. Establishing a MoQT session 37 | 3. Implementing handlers for various MoQT messages 38 | 4.
Publishing or subscribing to tracks 39 | 40 | ## Project Structure 41 | 42 | - `quicmoq/`: QUIC-specific implementation 43 | - `webtransportmoq/`: WebTransport-specific implementation 44 | - `internal/`: Internal implementation details 45 | - `examples/`: Example applications demonstrating usage 46 | - `integrationtests/`: Integration tests 47 | 48 | ## Requirements 49 | 50 | - Go 1.23.6 or later 51 | - Dependencies are managed via Go modules 52 | 53 | ## License 54 | 55 | See the [LICENSE](LICENSE) file for details. 56 | -------------------------------------------------------------------------------- /internal/wire/version_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestVersionsLen(t *testing.T) { 12 | cases := []struct { 13 | versions versions 14 | expected uint64 15 | }{ 16 | { 17 | versions: []Version{}, 18 | expected: 0, 19 | }, 20 | { 21 | versions: []Version{Version(0)}, 22 | expected: 1, 23 | }, 24 | { 25 | versions: []Version{Version(CurrentVersion)}, 26 | expected: 8, 27 | }, 28 | { 29 | versions: []Version{Version(1024)}, 30 | expected: 2, 31 | }, 32 | } 33 | for i, tc := range cases { 34 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 35 | res := tc.versions.Len() 36 | assert.Equal(t, tc.expected, res) 37 | }) 38 | } 39 | } 40 | 41 | func TestVersionsAppend(t *testing.T) { 42 | cases := []struct { 43 | versions versions 44 | buf []byte 45 | expected []byte 46 | }{ 47 | { 48 | versions: []Version{}, 49 | buf: []byte{}, 50 | expected: []byte{0x00}, 51 | }, 52 | { 53 | versions: []Version{0}, 54 | buf: []byte{}, 55 | expected: []byte{0x01, 0x00}, 56 | }, 57 | } 58 | for i, tc := range cases { 59 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 60 | res := tc.versions.append(tc.buf) 61 | assert.Equal(t, tc.expected, res) 62 | }) 63 | } 64 | } 65 | 66 | func TestParseVersions(t 
*testing.T) { 67 | cases := []struct { 68 | data []byte 69 | expect versions 70 | err error 71 | n int 72 | }{ 73 | { 74 | data: nil, 75 | expect: versions{}, 76 | err: io.EOF, 77 | n: 0, 78 | }, 79 | { 80 | data: []byte{0x01, 0x00}, 81 | expect: versions{0}, 82 | err: nil, 83 | n: 2, 84 | }, 85 | } 86 | for i, tc := range cases { 87 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 88 | res := versions{} 89 | n, err := res.parse(tc.data) 90 | assert.Equal(t, tc.expect, res) 91 | assert.Equal(t, tc.n, n) 92 | if tc.err != nil { 93 | assert.Equal(t, tc.err, err) 94 | } else { 95 | assert.NoError(t, err) 96 | } 97 | }) 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /fetch_stream.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import ( 4 | "github.com/mengelbart/moqtransport/internal/wire" 5 | "github.com/mengelbart/qlog" 6 | "github.com/mengelbart/qlog/moqt" 7 | ) 8 | 9 | type FetchStream struct { 10 | stream SendStream 11 | qlogger *qlog.Logger 12 | } 13 | 14 | func newFetchStream(stream SendStream, requestID uint64, qlogger *qlog.Logger) (*FetchStream, error) { 15 | fhm := &wire.FetchHeaderMessage{ 16 | RequestID: requestID, 17 | } 18 | buf := make([]byte, 0, 24) 19 | buf = fhm.Append(buf) 20 | _, err := stream.Write(buf) 21 | if err != nil { 22 | return nil, err 23 | } 24 | if qlogger != nil { 25 | qlogger.Log(moqt.StreamTypeSetEvent{ 26 | Owner: moqt.GetOwner(moqt.OwnerLocal), 27 | StreamID: stream.StreamID(), 28 | StreamType: moqt.StreamTypeFetchHeader, 29 | }) 30 | } 31 | return &FetchStream{ 32 | stream: stream, 33 | qlogger: qlogger, 34 | }, nil 35 | } 36 | 37 | func (f *FetchStream) WriteObject( 38 | groupID, subgroupID, objectID uint64, 39 | priority uint8, 40 | payload []byte, 41 | ) (int, error) { 42 | buf := make([]byte, 0, 1400) 43 | fo := wire.ObjectMessage{ 44 | GroupID: groupID, 45 | SubgroupID: subgroupID, 46 | ObjectID: 
objectID, 47 | PublisherPriority: priority, 48 | ObjectStatus: 0, 49 | ObjectPayload: payload, 50 | } 51 | buf = fo.AppendFetch(buf) 52 | _, err := f.stream.Write(buf) 53 | if err != nil { 54 | return 0, err 55 | } 56 | if f.qlogger != nil { 57 | f.qlogger.Log(moqt.FetchObjectEvent{ 58 | EventName: moqt.FetchObjectEventCreated, 59 | StreamID: f.stream.StreamID(), 60 | GroupID: groupID, 61 | SubgroupID: subgroupID, 62 | ObjectID: objectID, 63 | ExtensionHeadersLength: 0, 64 | ExtensionHeaders: nil, 65 | ObjectPayloadLength: uint64(len(payload)), 66 | ObjectStatus: 0, 67 | ObjectPayload: qlog.RawInfo{ 68 | Length: uint64(len(payload)), 69 | PayloadLength: uint64(len(payload)), 70 | Data: payload, 71 | }, 72 | }) 73 | } 74 | return len(payload), nil 75 | } 76 | 77 | func (f *FetchStream) Close() error { 78 | return f.stream.Close() 79 | } 80 | -------------------------------------------------------------------------------- /internal/wire/subscribe_update_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type SubscribeUpdateMessage struct { 10 | RequestID uint64 11 | StartLocation Location 12 | EndGroup uint64 13 | SubscriberPriority uint8 14 | Forward uint8 15 | Parameters KVPList 16 | } 17 | 18 | func (m *SubscribeUpdateMessage) LogValue() slog.Value { 19 | attrs := []slog.Attr{ 20 | slog.String("type", "subscribe_update"), 21 | slog.Uint64("request_id", m.RequestID), 22 | slog.Uint64("start_group", m.StartLocation.Group), 23 | slog.Uint64("start_object", m.StartLocation.Object), 24 | slog.Uint64("end_group", m.EndGroup), 25 | slog.Any("subscriber_priority", m.SubscriberPriority), 26 | slog.Uint64("number_of_parameters", uint64(len(m.Parameters))), 27 | } 28 | if len(m.Parameters) > 0 { 29 | attrs = append(attrs, 30 | slog.Any("setup_parameters", m.Parameters), 31 | ) 32 | } 33 | return slog.GroupValue(attrs...) 
34 | } 35 | 36 | func (m SubscribeUpdateMessage) Type() controlMessageType { 37 | return messageTypeSubscribeUpdate 38 | } 39 | 40 | func (m *SubscribeUpdateMessage) Append(buf []byte) []byte { 41 | buf = quicvarint.Append(buf, m.RequestID) 42 | buf = m.StartLocation.append(buf) 43 | buf = quicvarint.Append(buf, m.EndGroup) 44 | buf = append(buf, m.SubscriberPriority) 45 | buf = append(buf, m.Forward) 46 | return m.Parameters.appendNum(buf) 47 | } 48 | 49 | func (m *SubscribeUpdateMessage) parse(v Version, data []byte) (err error) { 50 | var n int 51 | 52 | m.RequestID, n, err = quicvarint.Parse(data) 53 | if err != nil { 54 | return err 55 | } 56 | data = data[n:] 57 | 58 | n, err = m.StartLocation.parse(v, data) 59 | if err != nil { 60 | return err 61 | } 62 | data = data[n:] 63 | 64 | m.EndGroup, n, err = quicvarint.Parse(data) 65 | if err != nil { 66 | return err 67 | } 68 | data = data[n:] 69 | 70 | if len(data) < 2 { 71 | return errLengthMismatch 72 | } 73 | m.SubscriberPriority = data[0] 74 | m.Forward = data[1] 75 | if m.Forward > 1 { // only 0 and 1 are valid, same check as PublishMessage/PublishOkMessage 76 | return errInvalidForwardFlag 77 | } 78 | data = data[2:] 79 | 80 | m.Parameters = KVPList{} 81 | return m.Parameters.parseNum(data) 82 | } 83 | -------------------------------------------------------------------------------- /internal/wire/announce_error_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestAnnounceErrorMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | aem AnnounceErrorMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | aem: AnnounceErrorMessage{ 19 | RequestID: 0, 20 | ErrorCode: 0, 21 | ReasonPhrase: "", 22 | }, 23 | buf: []byte{}, 24 | expect: []byte{ 25 | 0x00, 0x00, 0x00, 26 | }, 27 | }, 28 | { 29 | aem: AnnounceErrorMessage{ 30 | 
RequestID: 1, 31 | ErrorCode: 1, 32 | ReasonPhrase: "reason", 33 | }, 34
| buf: []byte{}, 35 | expect: append([]byte{0x01, 0x01, 0x06}, "reason"...), 36 | }, 37 | { 38 | aem: AnnounceErrorMessage{ 39 | RequestID: 1, 40 | ErrorCode: 1, 41 | ReasonPhrase: "reason", 42 | }, 43 | buf: []byte{0x0a, 0x0b, 0x0c, 0x0d}, 44 | expect: append([]byte{0x0a, 0x0b, 0x0c, 0x0d, 0x01, 0x01, 0x06}, "reason"...), 45 | }, 46 | } 47 | for i, tc := range cases { 48 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 49 | res := tc.aem.Append(tc.buf) 50 | assert.Equal(t, tc.expect, res) 51 | }) 52 | } 53 | } 54 | 55 | func TestParseAnnounceErrorMessage(t *testing.T) { 56 | cases := []struct { 57 | data []byte 58 | expect *AnnounceErrorMessage 59 | err error 60 | }{ 61 | { 62 | data: nil, 63 | expect: &AnnounceErrorMessage{}, 64 | err: io.EOF, 65 | }, 66 | { 67 | data: []byte{0x01, 0x03, 0x03, 'e', 'r'}, 68 | expect: &AnnounceErrorMessage{ 69 | RequestID: 1, 70 | ErrorCode: 3, 71 | ReasonPhrase: "", 72 | }, 73 | err: io.ErrUnexpectedEOF, 74 | }, 75 | { 76 | data: append([]byte{0x00, 0x01, 0x0d}, "reason phrase"...), 77 | expect: &AnnounceErrorMessage{ 78 | RequestID: 0, 79 | ErrorCode: 1, 80 | ReasonPhrase: "reason phrase", 81 | }, 82 | err: nil, 83 | }, 84 | } 85 | for i, tc := range cases { 86 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 87 | res := &AnnounceErrorMessage{} 88 | err := res.parse(CurrentVersion, tc.data) 89 | if tc.err != nil { 90 | assert.Equal(t, tc.err, err) 91 | assert.Equal(t, tc.expect, res) 92 | return 93 | } 94 | assert.NoError(t, err) 95 | assert.Equal(t, tc.expect, res) 96 | }) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /examples/date/README.md: -------------------------------------------------------------------------------- 1 | # Example: Date server and client 2 | 3 | The `examples` directory contains an implementation of a simple client and 4 | server which can publish and subscribe to `date` tracks using MoQ. 
The publisher 5 | of a `date` track publishes timestamps every second and sends them to the 6 | subscribers. Subscribers receive the timestamp objects. 7 | 8 | Both the client and the server can take the roles of subscriber and publisher. 9 | The server need certificates and key that the client trusts. There is 10 | an automatic setup for localhost certificates and keys, unless proper 11 | ones are provided. 12 | 13 | Download the example: 14 | 15 | ```shell 16 | git clone https://github.com/mengelbart/moqtransport.git 17 | cd moqtransport/examples/date 18 | ``` 19 | 20 | To run the server and publish, run: 21 | 22 | ```shell 23 | go run . -server -publish 24 | ``` 25 | 26 | Then, open a new shell and start a client to subscribe: 27 | 28 | ```shell 29 | go run . -subscribe 30 | ``` 31 | 32 | Alternatively, let the server subscribe to the `date` track from the client: 33 | 34 | ```shell 35 | go run . -server -subscribe 36 | go run . -publish 37 | ``` 38 | 39 | (again in different shells). 40 | 41 | The server is always prepared to connect via webtransport at the end point `/moq`. 42 | 43 | To make a WebTransport connection from the client, run: 44 | 45 | ```shell 46 | go run . 
-subscribe -webtransport -addr https://localhost:8080/moq 47 | ``` 48 | 49 | The following sequence diagram shows typical traffic between the client and the server when the server is the publisher: 50 | 51 | ```mermaid 52 | sequenceDiagram 53 | participant Publisher as server/Publisher 54 | participant Subscriber as client/Subscriber 55 | 56 | Publisher->>Publisher: Start server 57 | Publisher->>Publisher: Setup date track 58 | 59 | Subscriber->>Publisher: Connect 60 | 61 | Publisher->>Subscriber: Announce "clock" namespace 62 | Subscriber->>Publisher: Accept announcement 63 | 64 | Subscriber->>Publisher: Subscribe to "clock/second" track 65 | Publisher->>Subscriber: Accept subscription 66 | Publisher->>Publisher: Add subscriber to publishers list 67 | 68 | loop Every second 69 | Publisher->>Publisher: Create new timestamp 70 | Publisher->>Subscriber: Send timestamp object 71 | Subscriber->>Subscriber: Process and display timestamp 72 | end 73 | ``` 74 | 75 | -------------------------------------------------------------------------------- /internal/wire/publish_ok_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type PublishOkMessage struct { 10 | RequestID uint64 11 | Forward uint8 12 | SubscriberPriority uint8 13 | GroupOrder uint8 14 | FilterType FilterType 15 | Start Location 16 | EndGroup uint64 17 | Parameters KVPList 18 | } 19 | 20 | func (m *PublishOkMessage) LogValue() slog.Value { 21 | return slog.GroupValue() 22 | } 23 | 24 | func (m *PublishOkMessage) Type() controlMessageType { 25 | return messageTypePublishOk // the PUBLISH_OK reply type, not PUBLISH 26 | } 27 | 28 | func (m *PublishOkMessage) Append(buf []byte) []byte { 29 | buf = quicvarint.Append(buf, m.RequestID) 30 | buf = append(buf, m.Forward) 31 | buf = append(buf, m.SubscriberPriority) 32 | buf = append(buf, m.GroupOrder) 33 | buf = m.FilterType.append(buf) 34 | if m.FilterType ==
FilterTypeAbsoluteStart || m.FilterType == FilterTypeAbsoluteRange { 35 | buf = m.Start.append(buf) 36 | } 37 | if m.FilterType == FilterTypeAbsoluteRange { 38 | buf = quicvarint.Append(buf, m.EndGroup) 39 | } 40 | return m.Parameters.appendNum(buf) // count-prefixed so it round-trips with parseNum in parse 41 | } 42 | 43 | func (m *PublishOkMessage) parse(v Version, data []byte) (err error) { 44 | var n int 45 | m.RequestID, n, err = quicvarint.Parse(data) 46 | if err != nil { 47 | return err 48 | } 49 | data = data[n:] 50 | 51 | if len(data) < 3 { 52 | return errLengthMismatch 53 | } 54 | m.Forward = data[0] 55 | if m.Forward > 1 { 56 | return errInvalidForwardFlag 57 | } 58 | m.SubscriberPriority = data[1] 59 | m.GroupOrder = data[2] 60 | if m.GroupOrder > 2 { 61 | return errInvalidGroupOrder 62 | } 63 | data = data[3:] 64 | 65 | filterType, n, err := quicvarint.Parse(data) 66 | if err != nil { 67 | return err 68 | } 69 | m.FilterType = FilterType(filterType) 70 | if m.FilterType == 0 || m.FilterType > 4 { 71 | return errInvalidFilterType 72 | } 73 | data = data[n:] 74 | 75 | if m.FilterType == FilterTypeAbsoluteStart || m.FilterType == FilterTypeAbsoluteRange { 76 | n, err = m.Start.parse(v, data) 77 | if err != nil { 78 | return err 79 | } 80 | data = data[n:] 81 | } 82 | 83 | if m.FilterType == FilterTypeAbsoluteRange { 84 | m.EndGroup, n, err = quicvarint.Parse(data) 85 | if err != nil { 86 | return err 87 | } 88 | data = data[n:] 89 | } 90 | 91 | m.Parameters = KVPList{} 92 | return m.Parameters.parseNum(data) 93 | } 94 | -------------------------------------------------------------------------------- /internal/wire/publish_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | type PublishMessage struct { 10 | RequestID uint64 11 | TrackNamespace Tuple 12 | TrackName []byte 13 | TrackAlias uint64 14 | GroupOrder uint8 15 | 
ContentExists uint8 16 | LargestLocation
Location 17 | Forward uint8 18 | Parameters KVPList 19 | } 20 | 21 | func (m *PublishMessage) LogValue() slog.Value { 22 | return slog.GroupValue() 23 | } 24 | 25 | func (m *PublishMessage) Type() controlMessageType { 26 | return messageTypePublish 27 | } 28 | 29 | func (m *PublishMessage) Append(buf []byte) []byte { 30 | buf = quicvarint.Append(buf, m.RequestID) 31 | buf = m.TrackNamespace.append(buf) 32 | buf = appendVarIntBytes(buf, m.TrackName) 33 | buf = quicvarint.Append(buf, m.TrackAlias) 34 | buf = append(buf, m.GroupOrder) 35 | buf = append(buf, m.ContentExists) 36 | if m.ContentExists > 0 { 37 | buf = m.LargestLocation.append(buf) 38 | } 39 | buf = append(buf, m.Forward) 40 | return m.Parameters.appendNum(buf) 41 | } 42 | 43 | func (m *PublishMessage) parse(v Version, data []byte) (err error) { 44 | var n int 45 | m.RequestID, n, err = quicvarint.Parse(data) 46 | if err != nil { 47 | return err 48 | } 49 | data = data[n:] 50 | 51 | m.TrackNamespace, n, err = parseTuple(data) 52 | if err != nil { 53 | return err 54 | } 55 | data = data[n:] 56 | 57 | m.TrackName, n, err = parseVarIntBytes(data) 58 | if err != nil { 59 | return err 60 | } 61 | data = data[n:] 62 | 63 | m.TrackAlias, n, err = quicvarint.Parse(data) 64 | if err != nil { 65 | return err 66 | } 67 | data = data[n:] 68 | 69 | if len(data) < 2 { 70 | return errLengthMismatch 71 | } 72 | m.GroupOrder = data[0] 73 | if m.GroupOrder > 2 { 74 | return errInvalidGroupOrder 75 | } 76 | m.ContentExists = data[1] 77 | if m.ContentExists > 1 { 78 | return errInvalidContentExistsByte 79 | } 80 | data = data[2:] 81 | 82 | if m.ContentExists == 1 { 83 | n, err = m.LargestLocation.parse(v, data) 84 | if err != nil { 85 | return err 86 | } 87 | data = data[n:] 88 | } 89 | 90 | if len(data) < 1 { 91 | return errLengthMismatch 92 | } 93 | m.Forward = data[0] // Forward is the single byte written right before the parameters in Append 94 | if m.Forward > 1 { 95 | return errInvalidForwardFlag 96 | } 97 | data = data[1:] // consume Forward so parseNum starts at the parameter count 98 | m.Parameters = KVPList{} 99 | return m.Parameters.parseNum(data) 100 | } 101 |
-------------------------------------------------------------------------------- /mock_handler_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: github.com/mengelbart/moqtransport (interfaces: Handler) 3 | // 4 | // Generated by this command: 5 | // 6 | // mockgen -build_flags=-tags=gomock -typed -package moqtransport -write_package_comment=false -self_package github.com/mengelbart/moqtransport -destination mock_handler_test.go github.com/mengelbart/moqtransport Handler 7 | // 8 | 9 | package moqtransport 10 | 11 | import ( 12 | reflect "reflect" 13 | 14 | gomock "go.uber.org/mock/gomock" 15 | ) 16 | 17 | // MockHandler is a mock of Handler interface. 18 | type MockHandler struct { 19 | ctrl *gomock.Controller 20 | recorder *MockHandlerMockRecorder 21 | isgomock struct{} 22 | } 23 | 24 | // MockHandlerMockRecorder is the mock recorder for MockHandler. 25 | type MockHandlerMockRecorder struct { 26 | mock *MockHandler 27 | } 28 | 29 | // NewMockHandler creates a new mock instance. 30 | func NewMockHandler(ctrl *gomock.Controller) *MockHandler { 31 | mock := &MockHandler{ctrl: ctrl} 32 | mock.recorder = &MockHandlerMockRecorder{mock} 33 | return mock 34 | } 35 | 36 | // EXPECT returns an object that allows the caller to indicate expected use. 37 | func (m *MockHandler) EXPECT() *MockHandlerMockRecorder { 38 | return m.recorder 39 | } 40 | 41 | // Handle mocks base method. 42 | func (m *MockHandler) Handle(arg0 ResponseWriter, arg1 *Message) { 43 | m.ctrl.T.Helper() 44 | m.ctrl.Call(m, "Handle", arg0, arg1) 45 | } 46 | 47 | // Handle indicates an expected call of Handle. 
48 | func (mr *MockHandlerMockRecorder) Handle(arg0, arg1 any) *MockHandlerHandleCall { 49 | mr.mock.ctrl.T.Helper() 50 | call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handle", reflect.TypeOf((*MockHandler)(nil).Handle), arg0, arg1) 51 | return &MockHandlerHandleCall{Call: call} 52 | } 53 | 54 | // MockHandlerHandleCall wrap *gomock.Call 55 | type MockHandlerHandleCall struct { 56 | *gomock.Call 57 | } 58 | 59 | // Return rewrite *gomock.Call.Return 60 | func (c *MockHandlerHandleCall) Return() *MockHandlerHandleCall { 61 | c.Call = c.Call.Return() 62 | return c 63 | } 64 | 65 | // Do rewrite *gomock.Call.Do 66 | func (c *MockHandlerHandleCall) Do(f func(ResponseWriter, *Message)) *MockHandlerHandleCall { 67 | c.Call = c.Call.Do(f) 68 | return c 69 | } 70 | 71 | // DoAndReturn rewrite *gomock.Call.DoAndReturn 72 | func (c *MockHandlerHandleCall) DoAndReturn(f func(ResponseWriter, *Message)) *MockHandlerHandleCall { 73 | c.Call = c.Call.DoAndReturn(f) 74 | return c 75 | } 76 | -------------------------------------------------------------------------------- /internal/wire/stream_header_subgroup_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "io" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestStreamHeaderSubgroupMessageAppend(t *testing.T) { 14 | cases := []struct { 15 | shgm SubgroupHeaderMessage 16 | buf []byte 17 | expect []byte 18 | }{ 19 | { 20 | shgm: SubgroupHeaderMessage{ 21 | TrackAlias: 0, 22 | GroupID: 0, 23 | SubgroupID: 0, 24 | PublisherPriority: 0, 25 | }, 26 | buf: []byte{}, 27 | expect: []byte{byte(StreamTypeSubgroupSIDExt), 0x00, 0x00, 0x00, 0x00}, 28 | }, 29 | { 30 | shgm: SubgroupHeaderMessage{ 31 | TrackAlias: 1, 32 | GroupID: 2, 33 | SubgroupID: 3, 34 | PublisherPriority: 4, 35 | }, 36 | buf: []byte{}, 37 | expect: []byte{byte(StreamTypeSubgroupSIDExt), 0x01, 0x02, 0x03, 0x04}, 
38 | }, 39 | { 40 | shgm: SubgroupHeaderMessage{ 41 | TrackAlias: 1, 42 | GroupID: 2, 43 | SubgroupID: 3, 44 | PublisherPriority: 4, 45 | }, 46 | buf: []byte{0x0a, 0x0b}, 47 | expect: []byte{0x0a, 0x0b, byte(StreamTypeSubgroupSIDExt), 0x01, 0x02, 0x03, 0x04}, 48 | }, 49 | } 50 | for i, tc := range cases { 51 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 52 | res := tc.shgm.Append(tc.buf) 53 | assert.Equal(t, tc.expect, res) 54 | }) 55 | } 56 | } 57 | 58 | func TestParseStreamHeaderSubgroupMessage(t *testing.T) { 59 | cases := []struct { 60 | data []byte 61 | expect *SubgroupHeaderMessage 62 | err error 63 | }{ 64 | { 65 | data: nil, 66 | expect: &SubgroupHeaderMessage{}, 67 | err: io.EOF, 68 | }, 69 | { 70 | data: []byte{}, 71 | expect: &SubgroupHeaderMessage{}, 72 | err: io.EOF, 73 | }, 74 | { 75 | data: []byte{0x01, 0x02, 0x03, 0x04}, 76 | expect: &SubgroupHeaderMessage{ 77 | TrackAlias: 1, 78 | GroupID: 2, 79 | SubgroupID: 3, 80 | PublisherPriority: 4, 81 | }, 82 | err: nil, 83 | }, 84 | } 85 | for i, tc := range cases { 86 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 87 | reader := bufio.NewReader(bytes.NewReader(tc.data)) 88 | res := &SubgroupHeaderMessage{} 89 | err := res.parse(reader, true) 90 | assert.Equal(t, tc.expect, res) 91 | if tc.err != nil { 92 | assert.Equal(t, tc.err, err) 93 | } else { 94 | assert.NoError(t, err) 95 | } 96 | }) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /messages_test.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/mengelbart/moqtransport/internal/wire" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestKVPList_ToWire(t *testing.T) { 11 | t.Run("ToWire creates independent copy", func(t *testing.T) { 12 | original := KVPList{ 13 | {Type: 1, ValueVarInt: 100}, 14 | {Type: 2, ValueBytes: []byte("test")}, 15 | } 16 | 17 | wireCopy := 
original.ToWire() 18 | 19 | // Modify the original 20 | original[0].ValueVarInt = 200 21 | original[1].ValueBytes[0] = 'X' 22 | 23 | // Wire copy should be unchanged 24 | assert.Equal(t, uint64(100), wireCopy[0].ValueVarInt) 25 | assert.Equal(t, []byte("test"), wireCopy[1].ValueBytes) 26 | }) 27 | 28 | t.Run("ToWire handles nil slice", func(t *testing.T) { 29 | var original KVPList 30 | wireCopy := original.ToWire() 31 | assert.Nil(t, wireCopy) 32 | }) 33 | 34 | t.Run("ToWire handles empty slice", func(t *testing.T) { 35 | original := KVPList{} 36 | wireCopy := original.ToWire() 37 | assert.NotNil(t, wireCopy) 38 | assert.Len(t, wireCopy, 0) 39 | }) 40 | } 41 | 42 | func TestFromWire(t *testing.T) { 43 | t.Run("FromWire creates independent copy", func(t *testing.T) { 44 | original := wire.KVPList{ 45 | {Type: 1, ValueVarInt: 100}, 46 | {Type: 2, ValueBytes: []byte("test")}, 47 | } 48 | 49 | kvpCopy := FromWire(original) 50 | 51 | // Modify the original 52 | original[0].ValueVarInt = 200 53 | original[1].ValueBytes[0] = 'X' 54 | 55 | // KVP copy should be unchanged 56 | assert.Equal(t, uint64(100), kvpCopy[0].ValueVarInt) 57 | assert.Equal(t, []byte("test"), kvpCopy[1].ValueBytes) 58 | }) 59 | 60 | t.Run("FromWire handles nil slice", func(t *testing.T) { 61 | var original wire.KVPList 62 | kvpCopy := FromWire(original) 63 | assert.Nil(t, kvpCopy) 64 | }) 65 | 66 | t.Run("FromWire handles empty slice", func(t *testing.T) { 67 | original := wire.KVPList{} 68 | kvpCopy := FromWire(original) 69 | assert.NotNil(t, kvpCopy) 70 | assert.Len(t, kvpCopy, 0) 71 | }) 72 | } 73 | 74 | func TestKVPList_RoundTrip(t *testing.T) { 75 | t.Run("ToWire and FromWire preserve data", func(t *testing.T) { 76 | original := KVPList{ 77 | {Type: 1, ValueVarInt: 100}, 78 | {Type: 2, ValueBytes: []byte("test")}, 79 | {Type: 3, ValueVarInt: 42, ValueBytes: []byte("both")}, 80 | } 81 | 82 | // Convert to wire and back 83 | roundTrip := FromWire(original.ToWire()) 84 | 85 | assert.Equal(t, 
original, roundTrip) 86 | }) 87 | } -------------------------------------------------------------------------------- /subgroup.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import ( 4 | "github.com/mengelbart/moqtransport/internal/wire" 5 | "github.com/mengelbart/qlog" 6 | "github.com/mengelbart/qlog/moqt" 7 | ) 8 | 9 | type Subgroup struct { 10 | qlogger *qlog.Logger 11 | 12 | stream SendStream 13 | groupID uint64 14 | subgroupID uint64 15 | } 16 | 17 | func newSubgroup(stream SendStream, trackAlias, groupID, subgroupID uint64, publisherPriority uint8, qlogger *qlog.Logger) (*Subgroup, error) { 18 | shgm := &wire.SubgroupHeaderMessage{ 19 | TrackAlias: trackAlias, 20 | GroupID: groupID, 21 | SubgroupID: subgroupID, 22 | PublisherPriority: publisherPriority, 23 | } 24 | buf := make([]byte, 0, 40) 25 | buf = shgm.Append(buf) 26 | _, err := stream.Write(buf) 27 | if err != nil { 28 | return nil, err 29 | } 30 | if qlogger != nil { 31 | qlogger.Log(moqt.StreamTypeSetEvent{ 32 | Owner: moqt.GetOwner(moqt.OwnerLocal), 33 | StreamID: stream.StreamID(), 34 | StreamType: moqt.StreamTypeSubgroupHeader, 35 | }) 36 | } 37 | return &Subgroup{ 38 | qlogger: qlogger, 39 | stream: stream, 40 | groupID: groupID, 41 | subgroupID: subgroupID, 42 | }, nil 43 | } 44 | 45 | func (s *Subgroup) WriteObject(objectID uint64, payload []byte) (int, error) { 46 | var buf []byte 47 | if len(payload) > 0 { 48 | buf = make([]byte, 0, 16+len(payload)) 49 | } else { 50 | buf = make([]byte, 0, 24) 51 | } 52 | o := wire.ObjectMessage{ 53 | ObjectID: objectID, 54 | ObjectPayload: payload, 55 | } 56 | buf = o.AppendSubgroup(buf) 57 | _, err := s.stream.Write(buf) 58 | if err != nil { 59 | return 0, err 60 | } 61 | if s.qlogger != nil { 62 | gid := new(uint64) 63 | sid := new(uint64) 64 | *gid = s.groupID 65 | *sid = s.subgroupID 66 | s.qlogger.Log(moqt.SubgroupObjectEvent{ 67 | EventName: moqt.SubgroupObjectEventCreated, 68 | 
StreamID: s.stream.StreamID(), 69 | GroupID: gid, 70 | SubgroupID: sid, 71 | ObjectID: objectID, 72 | ExtensionHeadersLength: 0, 73 | ExtensionHeaders: nil, 74 | ObjectPayloadLength: uint64(len(payload)), 75 | ObjectStatus: 0, 76 | ObjectPayload: qlog.RawInfo{ 77 | Length: uint64(len(payload)), 78 | PayloadLength: uint64(len(payload)), 79 | Data: payload, 80 | }, 81 | }) 82 | } 83 | return len(payload), nil 84 | } 85 | 86 | // Close closes the subgroup. 87 | func (s *Subgroup) Close() error { 88 | return s.stream.Close() 89 | } 90 | -------------------------------------------------------------------------------- /internal/wire/track_status_request_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestTrackStatusRequestMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | aom TrackStatusRequestMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | aom: TrackStatusRequestMessage{ 19 | RequestID: 0, 20 | TrackNamespace: []string{""}, 21 | TrackName: []byte(""), 22 | Parameters: KVPList{}, 23 | }, 24 | buf: []byte{}, 25 | expect: []byte{ 26 | 0x00, 0x01, 0x00, 0x00, 0x00, 27 | }, 28 | }, 29 | { 30 | aom: TrackStatusRequestMessage{ 31 | RequestID: 0, 32 | TrackNamespace: []string{"tracknamespace"}, 33 | TrackName: []byte("track"), 34 | Parameters: KVPList{}, 35 | }, 36 | buf: []byte{0x0a, 0x0b}, 37 | expect: []byte{0x0a, 0x0b, 0x00, 0x01, 0x0e, 't', 'r', 'a', 'c', 'k', 'n', 'a', 'm', 'e', 's', 'p', 'a', 'c', 'e', 0x05, 't', 'r', 'a', 'c', 'k', 0x00}, 38 | }, 39 | } 40 | for i, tc := range cases { 41 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 42 | res := tc.aom.Append(tc.buf) 43 | assert.Equal(t, tc.expect, res) 44 | }) 45 | } 46 | } 47 | 48 | func TestParseTrackStatusRequestMessage(t *testing.T) { 49 | cases := []struct { 50 | data []byte 51 | expect 
*TrackStatusRequestMessage 52 | err error 53 | }{ 54 | { 55 | data: nil, 56 | expect: &TrackStatusRequestMessage{}, 57 | err: io.EOF, 58 | }, 59 | { 60 | data: []byte{0x00, 0x01, 0x0e, 't', 'r', 'a', 'c', 'k', 'n', 'a', 'm', 'e', 's', 'p', 'a', 'c', 'e', 0x05, 't', 'r', 'a', 'c', 'k', 0x00}, 61 | expect: &TrackStatusRequestMessage{ 62 | RequestID: 0, 63 | TrackNamespace: []string{"tracknamespace"}, 64 | TrackName: []byte("track"), 65 | Parameters: KVPList{}, 66 | }, 67 | err: nil, 68 | }, 69 | { 70 | data: append([]byte{0x00, 0x10}, append([]byte("tracknamespace"), 0x00)...), 71 | expect: &TrackStatusRequestMessage{ 72 | RequestID: 0, 73 | TrackNamespace: []string{}, 74 | TrackName: nil, 75 | Parameters: nil, 76 | }, 77 | err: errLengthMismatch, 78 | }, 79 | } 80 | for i, tc := range cases { 81 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 82 | res := &TrackStatusRequestMessage{} 83 | err := res.parse(CurrentVersion, tc.data) 84 | assert.Equal(t, tc.expect, res) 85 | if tc.err != nil { 86 | assert.Equal(t, tc.err, err) 87 | } else { 88 | assert.NoError(t, err) 89 | } 90 | }) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /internal/wire/key_value_pair.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/quic-go/quic-go/quicvarint" 9 | ) 10 | 11 | type KeyValuePair struct { 12 | Type uint64 13 | ValueBytes []byte 14 | ValueVarInt uint64 15 | } 16 | // String formats the pair; the value receiver makes plain KeyValuePair values (e.g. the elements KVPList.String prints with %v) satisfy fmt.Stringer. 17 | func (p KeyValuePair) String() string { 18 | if p.Type%2 == 1 { 19 | return fmt.Sprintf("{key: %v, value: '%v'}", p.Type, p.ValueBytes) 20 | } 21 | return fmt.Sprintf("{key: %v, value: '%v'}", p.Type, p.ValueVarInt) 22 | } 23 | 24 | func (p KeyValuePair) length() uint64 { 25 | length := uint64(quicvarint.Len(p.Type)) 26 | if p.Type%2 == 1 { 27 | length += uint64(quicvarint.Len(uint64(len(p.ValueBytes)))) 28 | length += uint64(len(p.ValueBytes))
29 | return length 30 | } 31 | length += uint64(quicvarint.Len(p.ValueVarInt)) 32 | return length 33 | } 34 | 35 | func (p KeyValuePair) append(buf []byte) []byte { 36 | buf = quicvarint.Append(buf, p.Type) 37 | if p.Type%2 == 1 { 38 | buf = quicvarint.Append(buf, uint64(len(p.ValueBytes))) 39 | return append(buf, p.ValueBytes...) 40 | } 41 | return quicvarint.Append(buf, p.ValueVarInt) 42 | } 43 | // parse decodes one pair from data and returns the bytes consumed; a truncated value yields errLengthMismatch instead of a panic. 44 | func (p *KeyValuePair) parse(data []byte) (int, error) { 45 | var n, parsed int 46 | var err error 47 | p.Type, n, err = quicvarint.Parse(data) 48 | parsed += n 49 | if err != nil { 50 | return n, err 51 | } 52 | data = data[n:] 53 | 54 | if p.Type%2 == 1 { 55 | var length uint64 56 | length, n, err = quicvarint.Parse(data) 57 | parsed += n 58 | if err != nil { 59 | return parsed, err 60 | } 61 | data = data[n:] 62 | // length comes from the peer: reject it before allocating or slicing past the end of data. 63 | if uint64(len(data)) < length { 64 | return parsed, errLengthMismatch 65 | } 66 | p.ValueBytes = make([]byte, length) // TODO: Don't allocate memory here? 67 | parsed += copy(p.ValueBytes, data[:length]) 68 | return parsed, nil 69 | } 70 | 71 | p.ValueVarInt, n, err = quicvarint.Parse(data) 72 | parsed += n 73 | return parsed, err 74 | } 75 | 76 | func (p *KeyValuePair) parseReader(br *bufio.Reader) error { 77 | var err error 78 | p.Type, err = quicvarint.Read(br) 79 | if err != nil { 80 | return err 81 | } 82 | if p.Type%2 == 1 { 83 | var length uint64 84 | length, err = quicvarint.Read(br) 85 | if err != nil { 86 | return err 87 | } 88 | p.ValueBytes = make([]byte, length) // NOTE(review): length is unbounded wire input; consider capping it before allocating 89 | var m int 90 | m, err = io.ReadFull(br, p.ValueBytes) 91 | if err != nil { 92 | return err 93 | } 94 | if uint64(m) != length { 95 | return errLengthMismatch 96 | } 97 | return nil 98 | } 99 | p.ValueVarInt, err = quicvarint.Read(br) 100 | return err 101 | } 102 | -------------------------------------------------------------------------------- /internal/wire/announce_cancel_message_test.go: -------------------------------------------------------------------------------- 1
| package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestAnnounceCancelMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | aom AnnounceCancelMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | aom: AnnounceCancelMessage{ 19 | TrackNamespace: []string{""}, 20 | ErrorCode: 1, 21 | ReasonPhrase: "reason", 22 | }, 23 | buf: []byte{}, 24 | expect: []byte{ 25 | 0x01, 0x00, 0x01, 0x06, 'r', 'e', 'a', 's', 'o', 'n', 26 | }, 27 | }, 28 | { 29 | aom: AnnounceCancelMessage{ 30 | TrackNamespace: []string{"tracknamespace"}, 31 | ErrorCode: 1, 32 | ReasonPhrase: "reason", 33 | }, 34 | buf: []byte{0x0a, 0x0b}, 35 | expect: []byte{ 36 | 0x0a, 0x0b, 37 | 0x01, 0x0e, 't', 'r', 'a', 'c', 'k', 'n', 'a', 'm', 'e', 's', 'p', 'a', 'c', 'e', 38 | 0x01, 39 | 0x06, 'r', 'e', 'a', 's', 'o', 'n', 40 | }, 41 | }, 42 | } 43 | for i, tc := range cases { 44 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 45 | res := tc.aom.Append(tc.buf) 46 | assert.Equal(t, tc.expect, res) 47 | }) 48 | } 49 | } 50 | 51 | func TestParseAnnounceCancelMessage(t *testing.T) { 52 | cases := []struct { 53 | data []byte 54 | expect *AnnounceCancelMessage 55 | err error 56 | }{ 57 | { 58 | data: nil, 59 | expect: &AnnounceCancelMessage{}, 60 | err: io.EOF, 61 | }, 62 | { 63 | data: append( 64 | []byte{0x01, 0x0E}, append([]byte("tracknamespace"), 0x00, 0x00)..., 65 | ), 66 | expect: &AnnounceCancelMessage{ 67 | TrackNamespace: []string{"tracknamespace"}, 68 | ErrorCode: 0, 69 | ReasonPhrase: "", 70 | }, 71 | err: nil, 72 | }, 73 | { 74 | data: append([]byte{0x01, 0x05}, append([]byte("track"), []byte{0x01, 0x06, 'r', 'e', 'a', 's', 'o', 'n', 'p', 'h', 'r', 'a', 's', 'e'}...)...), 75 | expect: &AnnounceCancelMessage{ 76 | TrackNamespace: []string{"track"}, 77 | ErrorCode: 1, 78 | ReasonPhrase: "reason", 79 | }, 80 | err: nil, 81 | }, 82 | { 83 | data: append([]byte{0x01, 0x0F}, "tracknamespace"...), 84 | expect: 
&AnnounceCancelMessage{ 85 | TrackNamespace: []string{}, 86 | }, 87 | err: errLengthMismatch, 88 | }, 89 | } 90 | for i, tc := range cases { 91 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 92 | res := &AnnounceCancelMessage{} 93 | err := res.parse(CurrentVersion, tc.data) 94 | assert.Equal(t, tc.expect, res) 95 | if tc.err != nil { 96 | assert.Equal(t, tc.err, err) 97 | } else { 98 | assert.NoError(t, err) 99 | } 100 | }) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /internal/wire/varint_bytes_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestAppendVarIntString(t *testing.T) { 13 | cases := []struct { 14 | buf []byte 15 | in string 16 | expect []byte 17 | }{ 18 | { 19 | buf: nil, 20 | in: "", 21 | expect: []byte{0x00}, 22 | }, 23 | { 24 | buf: []byte{0x01, 0x02, 0x03}, 25 | in: "", 26 | expect: []byte{0x01, 0x02, 0x03, 0x00}, 27 | }, 28 | { 29 | buf: []byte{}, 30 | in: "hello world", 31 | expect: append([]byte{0x0b}, []byte("hello world")...), 32 | }, 33 | { 34 | buf: []byte{0x01, 0x02, 0x03}, 35 | in: "hello world", 36 | expect: append([]byte{0x01, 0x02, 0x03, 0x0b}, []byte("hello world")...), 37 | }, 38 | } 39 | for i, tc := range cases { 40 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 41 | res := appendVarIntBytes(tc.buf, []byte(tc.in)) 42 | assert.Equal(t, tc.expect, res) 43 | }) 44 | } 45 | } 46 | 47 | func TestVarIntStringLen(t *testing.T) { 48 | cases := []struct { 49 | in string 50 | expect uint64 51 | }{ 52 | { 53 | in: "", 54 | expect: 1, 55 | }, 56 | { 57 | in: "hello world", 58 | expect: 1 + 11, 59 | }, 60 | { 61 | in: strings.Repeat("AAAAAAAA", 20), 62 | expect: 2 + 160, 63 | }, 64 | } 65 | for i, tc := range cases { 66 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 67 | res := 
varIntBytesLen(tc.in) 68 | assert.Equal(t, tc.expect, res) 69 | }) 70 | } 71 | } 72 | 73 | func TestParseVarIntBytes(t *testing.T) { 74 | cases := []struct { 75 | data []byte 76 | expect []byte 77 | err error 78 | n int 79 | }{ 80 | { 81 | data: nil, 82 | expect: []byte(""), 83 | err: io.EOF, 84 | n: 0, 85 | }, 86 | { 87 | data: []byte{}, 88 | expect: []byte(""), 89 | err: io.EOF, 90 | n: 0, 91 | }, 92 | { 93 | data: append([]byte{0x01}, "A"...), 94 | expect: []byte("A"), 95 | err: nil, 96 | n: 2, 97 | }, 98 | { 99 | data: append([]byte{0x04}, "ABC"...), 100 | expect: []byte(""), 101 | err: io.ErrUnexpectedEOF, 102 | n: 4, 103 | }, 104 | { 105 | data: append([]byte{0x02}, "ABC"...), 106 | expect: []byte("AB"), 107 | err: nil, 108 | n: 3, 109 | }, 110 | } 111 | for i, tc := range cases { 112 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 113 | res, n, err := parseVarIntBytes(tc.data) 114 | assert.Equal(t, tc.expect, res) 115 | assert.Equal(t, tc.n, n) 116 | if tc.err != nil { 117 | assert.Equal(t, tc.err, err) 118 | } else { 119 | assert.NoError(t, err) 120 | } 121 | }) 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /quicmoq/connection.go: -------------------------------------------------------------------------------- 1 | package quicmoq 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/mengelbart/moqtransport" 7 | "github.com/quic-go/quic-go" 8 | ) 9 | 10 | type connection struct { 11 | connection *quic.Conn 12 | perspective moqtransport.Perspective 13 | } 14 | 15 | func NewServer(conn *quic.Conn) moqtransport.Connection { 16 | return New(conn, moqtransport.PerspectiveServer) 17 | } 18 | 19 | func NewClient(conn *quic.Conn) moqtransport.Connection { 20 | return New(conn, moqtransport.PerspectiveClient) 21 | } 22 | 23 | func New(conn *quic.Conn, perspective moqtransport.Perspective) moqtransport.Connection { 24 | return &connection{conn, perspective} 25 | } 26 | 27 | func (c *connection) 
AcceptStream(ctx context.Context) (moqtransport.Stream, error) { 28 | s, err := c.connection.AcceptStream(ctx) 29 | if err != nil { 30 | return nil, err 31 | } 32 | return &Stream{ 33 | stream: s, 34 | }, nil 35 | } 36 | 37 | func (c *connection) AcceptUniStream(ctx context.Context) (moqtransport.ReceiveStream, error) { 38 | s, err := c.connection.AcceptUniStream(ctx) 39 | if err != nil { 40 | return nil, err 41 | } 42 | return &ReceiveStream{ 43 | stream: s, 44 | }, nil 45 | } 46 | 47 | func (c *connection) OpenStream() (moqtransport.Stream, error) { 48 | s, err := c.connection.OpenStream() 49 | if err != nil { 50 | return nil, err 51 | } 52 | return &Stream{ 53 | stream: s, 54 | }, nil 55 | } 56 | 57 | func (c *connection) OpenStreamSync(ctx context.Context) (moqtransport.Stream, error) { 58 | s, err := c.connection.OpenStreamSync(ctx) 59 | if err != nil { 60 | return nil, err 61 | } 62 | return &Stream{ 63 | stream: s, 64 | }, nil 65 | } 66 | 67 | func (c *connection) OpenUniStream() (moqtransport.SendStream, error) { 68 | s, err := c.connection.OpenUniStream() 69 | if err != nil { 70 | return nil, err 71 | } 72 | return &SendStream{ 73 | stream: s, 74 | }, nil 75 | } 76 | 77 | func (c *connection) OpenUniStreamSync(ctx context.Context) (moqtransport.SendStream, error) { 78 | s, err := c.connection.OpenUniStreamSync(ctx) 79 | if err != nil { 80 | return nil, err 81 | } 82 | return &SendStream{ 83 | stream: s, 84 | }, nil 85 | } 86 | 87 | func (c *connection) SendDatagram(b []byte) error { 88 | return c.connection.SendDatagram(b) 89 | } 90 | 91 | func (c *connection) ReceiveDatagram(ctx context.Context) ([]byte, error) { 92 | return c.connection.ReceiveDatagram(ctx) 93 | } 94 | 95 | func (c *connection) CloseWithError(e uint64, msg string) error { 96 | return c.connection.CloseWithError(quic.ApplicationErrorCode(e), msg) 97 | } 98 | 99 | func (c *connection) Context() context.Context { 100 | return c.connection.Context() 101 | } 102 | 103 | func (c *connection) 
Protocol() moqtransport.Protocol { 104 | return moqtransport.ProtocolQUIC 105 | } 106 | 107 | func (c *connection) Perspective() moqtransport.Perspective { 108 | return c.perspective 109 | } 110 | -------------------------------------------------------------------------------- /integrationtests/fetch_test.go: -------------------------------------------------------------------------------- 1 | package integrationtests 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/mengelbart/moqtransport" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestFetch(t *testing.T) { 13 | t.Run("success", func(t *testing.T) { 14 | sConn, cConn, cancel := connect(t) 15 | defer cancel() 16 | 17 | handler := moqtransport.HandlerFunc(func(w moqtransport.ResponseWriter, m *moqtransport.Message) { 18 | assert.Equal(t, moqtransport.MessageFetch, m.Method) 19 | assert.NotNil(t, w) 20 | assert.NoError(t, w.Accept()) 21 | }) 22 | _, ct, cancel := setup(t, sConn, cConn, handler) 23 | defer cancel() 24 | 25 | rt, err := ct.Fetch(context.Background(), []string{"namespace"}, "track") 26 | assert.NoError(t, err) 27 | assert.NotNil(t, rt) 28 | }) 29 | t.Run("auth_error", func(t *testing.T) { 30 | sConn, cConn, cancel := connect(t) 31 | defer cancel() 32 | 33 | handler := moqtransport.HandlerFunc(func(w moqtransport.ResponseWriter, m *moqtransport.Message) { 34 | assert.Equal(t, moqtransport.MessageFetch, m.Method) 35 | assert.NotNil(t, w) 36 | assert.NoError(t, w.Reject(uint64(moqtransport.ErrorCodeFetchUnauthorized), "unauthorized")) 37 | }) 38 | _, ct, cancel := setup(t, sConn, cConn, handler) 39 | defer cancel() 40 | 41 | rt, err := ct.Fetch(context.Background(), []string{"namespace"}, "track") 42 | assert.Error(t, err) 43 | assert.ErrorContains(t, err, "unauthorized") 44 | assert.Nil(t, rt) 45 | }) 46 | 47 | t.Run("receive_objects", func(t *testing.T) { 48 | sConn, cConn, cancel := connect(t) 49 | defer cancel() 50 | 51 | publisherCh := make(chan 
moqtransport.FetchPublisher, 1) 52 | 53 | handler := moqtransport.HandlerFunc(func(w moqtransport.ResponseWriter, m *moqtransport.Message) { 54 | assert.Equal(t, moqtransport.MessageFetch, m.Method) 55 | assert.NotNil(t, w) 56 | assert.NoError(t, w.Accept()) 57 | publisher, ok := w.(moqtransport.FetchPublisher) 58 | assert.True(t, ok) 59 | publisherCh <- publisher 60 | }) 61 | _, ct, cancel := setup(t, sConn, cConn, handler) 62 | defer cancel() 63 | 64 | rt, err := ct.Fetch(context.Background(), []string{"namespace"}, "track") 65 | assert.NoError(t, err) 66 | assert.NotNil(t, rt) 67 | 68 | var publisher moqtransport.FetchPublisher 69 | select { 70 | case publisher = <-publisherCh: 71 | case <-time.After(time.Second): 72 | assert.FailNow(t, "timeout while waiting for publisher") 73 | } 74 | 75 | fs, err := publisher.FetchStream() 76 | assert.NoError(t, err) 77 | n, err := fs.WriteObject(1, 2, 3, 0, []byte("hello fetch")) 78 | assert.NoError(t, err) 79 | assert.Equal(t, 11, n) 80 | assert.NoError(t, fs.Close()) 81 | 82 | ctx2, cancelCtx2 := context.WithTimeout(context.Background(), time.Second) 83 | defer cancelCtx2() 84 | 85 | o, err := rt.ReadObject(ctx2) 86 | assert.NoError(t, err) 87 | assert.Equal(t, &moqtransport.Object{ 88 | GroupID: 1, 89 | SubGroupID: 2, 90 | ObjectID: 3, 91 | Payload: []byte("hello fetch"), 92 | }, o) 93 | }) 94 | } 95 | -------------------------------------------------------------------------------- /webtransportmoq/connection.go: -------------------------------------------------------------------------------- 1 | package webtransportmoq 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/mengelbart/moqtransport" 7 | "github.com/quic-go/webtransport-go" 8 | ) 9 | 10 | type webTransportConn struct { 11 | session *webtransport.Session 12 | perspective moqtransport.Perspective 13 | } 14 | 15 | func NewServer(conn *webtransport.Session) moqtransport.Connection { 16 | return New(conn, moqtransport.PerspectiveServer) 17 | } 18 | 19 | func 
NewClient(conn *webtransport.Session) moqtransport.Connection { 20 | return New(conn, moqtransport.PerspectiveClient) 21 | } 22 | 23 | func New(session *webtransport.Session, perspective moqtransport.Perspective) moqtransport.Connection { 24 | return &webTransportConn{session, perspective} 25 | } 26 | 27 | func (c *webTransportConn) AcceptStream(ctx context.Context) (moqtransport.Stream, error) { 28 | s, err := c.session.AcceptStream(ctx) 29 | if err != nil { 30 | return nil, err 31 | } 32 | return &Stream{ 33 | stream: s, 34 | }, nil 35 | } 36 | 37 | func (c *webTransportConn) AcceptUniStream(ctx context.Context) (moqtransport.ReceiveStream, error) { 38 | s, err := c.session.AcceptUniStream(ctx) 39 | if err != nil { 40 | return nil, err 41 | } 42 | return &ReceiveStream{ 43 | stream: s, 44 | }, nil 45 | } 46 | 47 | func (c *webTransportConn) OpenStream() (moqtransport.Stream, error) { 48 | s, err := c.session.OpenStream() 49 | if err != nil { 50 | return nil, err 51 | } 52 | return &Stream{ 53 | stream: s, 54 | }, nil 55 | } 56 | 57 | func (c *webTransportConn) OpenStreamSync(ctx context.Context) (moqtransport.Stream, error) { 58 | s, err := c.session.OpenStreamSync(ctx) 59 | if err != nil { 60 | return nil, err 61 | } 62 | return &Stream{ 63 | stream: s, 64 | }, nil 65 | } 66 | 67 | func (c *webTransportConn) OpenUniStream() (moqtransport.SendStream, error) { 68 | s, err := c.session.OpenUniStream() 69 | if err != nil { 70 | return nil, err 71 | } 72 | return &SendStream{ 73 | stream: s, 74 | }, nil 75 | } 76 | 77 | func (c *webTransportConn) OpenUniStreamSync(ctx context.Context) (moqtransport.SendStream, error) { 78 | s, err := c.session.OpenUniStreamSync(ctx) 79 | if err != nil { 80 | return nil, err 81 | } 82 | return &SendStream{ 83 | stream: s, 84 | }, nil 85 | } 86 | 87 | func (c *webTransportConn) SendDatagram(b []byte) error { 88 | return c.session.SendDatagram(b) 89 | } 90 | 91 | func (c *webTransportConn) ReceiveDatagram(ctx context.Context) ([]byte, 
error) { 92 | return c.session.ReceiveDatagram(ctx) 93 | } 94 | 95 | func (c *webTransportConn) CloseWithError(e uint64, msg string) error { 96 | return c.session.CloseWithError(webtransport.SessionErrorCode(e), msg) 97 | } 98 | 99 | func (c *webTransportConn) Context() context.Context { 100 | return c.session.Context() 101 | } 102 | 103 | func (c *webTransportConn) Protocol() moqtransport.Protocol { 104 | return moqtransport.ProtocolWebTransport 105 | } 106 | 107 | func (c *webTransportConn) Perspective() moqtransport.Perspective { 108 | return c.perspective 109 | } 110 | -------------------------------------------------------------------------------- /internal/wire/control_message_parser.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/quic-go/quic-go/quicvarint" 9 | ) 10 | 11 | type ControlMessageParser struct { 12 | reader messageReader 13 | } 14 | 15 | func NewControlMessageParser(r io.Reader) *ControlMessageParser { 16 | return &ControlMessageParser{ 17 | reader: bufio.NewReader(r), 18 | } 19 | } 20 | 21 | func (p *ControlMessageParser) Parse() (ControlMessage, error) { 22 | mt, err := quicvarint.Read(p.reader) 23 | if err != nil { 24 | return nil, err 25 | } 26 | hi, err := p.reader.ReadByte() 27 | if err != nil { 28 | return nil, err 29 | } 30 | lo, err := p.reader.ReadByte() 31 | if err != nil { 32 | return nil, err 33 | } 34 | length := uint16(hi)<<8 | uint16(lo) 35 | 36 | msg := make([]byte, length) 37 | n, err := io.ReadFull(p.reader, msg) 38 | if err != nil { 39 | return nil, err 40 | } 41 | if n != int(length) { 42 | return nil, errLengthMismatch 43 | } 44 | 45 | var m ControlMessage 46 | switch controlMessageType(mt) { 47 | case messageTypeClientSetup: 48 | m = &ClientSetupMessage{} 49 | case messageTypeServerSetup: 50 | m = &ServerSetupMessage{} 51 | 52 | case messageTypeGoAway: 53 | m = &GoAwayMessage{} 54 | 55 | case 
messageTypeMaxRequestID: 56 | m = &MaxRequestIDMessage{} 57 | case messageTypeRequestsBlocked: 58 | m = &RequestsBlockedMessage{} 59 | 60 | case messageTypeSubscribe: 61 | m = &SubscribeMessage{} 62 | case messageTypeSubscribeOk: 63 | m = &SubscribeOkMessage{} 64 | case messageTypeSubscribeError: 65 | m = &SubscribeErrorMessage{} 66 | case messageTypeUnsubscribe: 67 | m = &UnsubscribeMessage{} 68 | case messageTypeSubscribeUpdate: 69 | m = &SubscribeUpdateMessage{} 70 | case messageTypeSubscribeDone: 71 | m = &SubscribeDoneMessage{} 72 | 73 | case messageTypeFetch: 74 | m = &FetchMessage{} 75 | case messageTypeFetchOk: 76 | m = &FetchOkMessage{} 77 | case messageTypeFetchError: 78 | m = &FetchErrorMessage{} 79 | case messageTypeFetchCancel: 80 | m = &FetchCancelMessage{} 81 | 82 | case messageTypeTrackStatus: 83 | m = &TrackStatusRequestMessage{} 84 | case messageTypeTrackStatusOk: 85 | m = &TrackStatusMessage{} 86 | 87 | case messageTypeAnnounce: 88 | m = &AnnounceMessage{} 89 | case messageTypeAnnounceOk: 90 | m = &AnnounceOkMessage{} 91 | case messageTypeAnnounceError: 92 | m = &AnnounceErrorMessage{} 93 | case messageTypeUnannounce: 94 | m = &UnannounceMessage{} 95 | case messageTypeAnnounceCancel: 96 | m = &AnnounceCancelMessage{} 97 | 98 | case messageTypeSubscribeNamespace: 99 | m = &SubscribeAnnouncesMessage{} 100 | case messageTypeSubscribeNamespaceOk: 101 | m = &SubscribeAnnouncesOkMessage{} 102 | case messageTypeSubscribeNamespaceError: 103 | m = &SubscribeAnnouncesErrorMessage{} 104 | case messageTypeUnsubscribeNamespace: 105 | m = &UnsubscribeAnnouncesMessage{} 106 | default: 107 | return nil, fmt.Errorf("%w: %v", errInvalidMessageType, mt) 108 | } 109 | err = m.parse(CurrentVersion, msg) 110 | return m, err 111 | } 112 | -------------------------------------------------------------------------------- /internal/wire/kvp_list.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | 
"bufio" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/quic-go/quic-go/quicvarint" 9 | ) 10 | 11 | // Setup parameters 12 | const ( 13 | PathParameterKey = 0x01 14 | MaxRequestIDParameterKey = 0x02 15 | MaxAuthTokenCacheSizeParameterKey = 0x04 16 | ) 17 | 18 | // Version specific parameters 19 | const ( 20 | DeliveryTimeoutParameterKey = 0x02 21 | AuthorizationTokenParameterKey = 0x03 22 | MaxCacheDurationParameterKey = 0x04 23 | ) 24 | 25 | type KVPList []KeyValuePair 26 | 27 | func (pp KVPList) length() uint64 { 28 | length := uint64(0) 29 | for _, p := range pp { 30 | length += p.length() 31 | } 32 | return length 33 | } 34 | 35 | // Appends pp to buf with a prefix indicating the number of elements 36 | func (pp KVPList) appendNum(buf []byte) []byte { 37 | buf = quicvarint.Append(buf, uint64(len(pp))) 38 | return pp.append(buf) 39 | } 40 | 41 | // Appends pp to buf with a prefix indicating the length in bytes 42 | func (pp KVPList) appendLength(buf []byte) []byte { 43 | buf = quicvarint.Append(buf, pp.length()) 44 | return pp.append(buf) 45 | } 46 | 47 | func (pp KVPList) append(buf []byte) []byte { 48 | for _, p := range pp { 49 | buf = p.append(buf) 50 | } 51 | return buf 52 | } 53 | 54 | func (pp KVPList) String() string { 55 | res := "[" 56 | i := 0 57 | for _, v := range pp { 58 | if i < len(pp)-1 { 59 | res += fmt.Sprintf("%v, ", v) 60 | } else { 61 | res += fmt.Sprintf("%v", v) 62 | } 63 | i++ 64 | } 65 | return res + "]" 66 | } 67 | 68 | func (pp *KVPList) parseLengthReader(br *bufio.Reader) error { 69 | length, err := quicvarint.Read(br) 70 | if err != nil { 71 | return err 72 | } 73 | if length == 0 { 74 | return nil 75 | } 76 | lr := io.LimitReader(br, int64(length)) 77 | lbr := bufio.NewReader(quicvarint.NewReader(lr)) 78 | for { 79 | var hdrExt KeyValuePair 80 | if err = hdrExt.parseReader(lbr); err != nil { 81 | return err 82 | } 83 | *pp = append(*pp, hdrExt) 84 | } 85 | } 86 | 87 | // Parses pp from data based on a length prefix in number of elements 88 
| func (pp *KVPList) parseNum(data []byte) error { 89 | numParameters, n, err := quicvarint.Parse(data) 90 | if err != nil { 91 | return err 92 | } 93 | data = data[n:] 94 | 95 | for i := uint64(0); i < numParameters; i++ { 96 | param := KeyValuePair{} 97 | n, err := param.parse(data) 98 | if err != nil { 99 | return err 100 | } 101 | data = data[n:] 102 | *pp = append(*pp, param) 103 | } 104 | return nil 105 | } 106 | 107 | // parseLength parses pp from data based on a length prefix in bytes and returns the number of bytes consumed. 108 | func (pp *KVPList) parseLength(data []byte) (parsed int, err error) { 109 | length, n, err := quicvarint.Parse(data) 110 | parsed += n 111 | if err != nil { 112 | return parsed, err 113 | } 114 | data = data[n:] 115 | // length comes from the peer: reject it if it claims more bytes than are available. 116 | if uint64(len(data)) < length { 117 | return parsed, errLengthMismatch 118 | } 119 | data = data[:length] 120 | 121 | for len(data) > 0 { 122 | var hdrExt KeyValuePair 123 | n, err = hdrExt.parse(data) 124 | parsed += n 125 | if err != nil { 126 | return parsed, err 127 | } 128 | data = data[n:] 129 | *pp = append(*pp, hdrExt) 130 | } 131 | return parsed, nil 132 | } 133 | -------------------------------------------------------------------------------- /internal/wire/subscribe_ok_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | "time" 6 | 7 | "github.com/quic-go/quic-go/quicvarint" 8 | ) 9 | 10 | type SubscribeOkMessage struct { 11 | RequestID uint64 12 | TrackAlias uint64 13 | Expires time.Duration 14 | GroupOrder uint8 15 | ContentExists bool 16 | LargestLocation Location 17 | Parameters KVPList 18 | } 19 | 20 | func (m *SubscribeOkMessage) LogValue() slog.Value { 21 | ce := 0 22 | if m.ContentExists { 23 | ce = 1 24 | } 25 | attrs := []slog.Attr{ 26 | slog.String("type", "subscribe_ok"), 27 | slog.Uint64("track_alias", m.TrackAlias), 28 | slog.Uint64("request_id", m.RequestID), 29 | slog.Uint64("expires", uint64(m.Expires.Milliseconds())), 30 | slog.Any("group_order", m.GroupOrder), 31 | slog.Int("content_exists", ce), 32 | } 33 | if m.ContentExists { 34 | attrs = append(attrs, 35 |
slog.Uint64("largest_group_id", m.LargestLocation.Group), 36 | slog.Uint64("largest_object_id", m.LargestLocation.Object), 37 | ) 38 | } 39 | attrs = append(attrs, 40 | slog.Uint64("number_of_parameters", uint64(len(m.Parameters))), 41 | ) 42 | if len(m.Parameters) > 0 { 43 | attrs = append(attrs, 44 | slog.Any("subscribe_parameters", m.Parameters), 45 | ) 46 | 47 | } 48 | return slog.GroupValue(attrs...) 49 | } 50 | 51 | func (m SubscribeOkMessage) Type() controlMessageType { 52 | return messageTypeSubscribeOk 53 | } 54 | 55 | func (m *SubscribeOkMessage) Append(buf []byte) []byte { 56 | buf = quicvarint.Append(buf, m.RequestID) 57 | buf = quicvarint.Append(buf, m.TrackAlias) 58 | buf = quicvarint.Append(buf, uint64(m.Expires.Milliseconds())) // Expires travels in milliseconds, matching parse and LogValue 59 | buf = append(buf, m.GroupOrder) 60 | if m.ContentExists { 61 | buf = append(buf, 1) // ContentExists=true 62 | buf = m.LargestLocation.append(buf) 63 | return m.Parameters.appendNum(buf) 64 | } 65 | buf = append(buf, 0) // ContentExists=false 66 | return m.Parameters.appendNum(buf) 67 | } 68 | 69 | func (m *SubscribeOkMessage) parse(v Version, data []byte) (err error) { 70 | var n int 71 | m.RequestID, n, err = quicvarint.Parse(data) 72 | if err != nil { 73 | return 74 | } 75 | data = data[n:] 76 | 77 | m.TrackAlias, n, err = quicvarint.Parse(data) 78 | if err != nil { 79 | return 80 | } 81 | data = data[n:] 82 | 83 | expires, n, err := quicvarint.Parse(data) 84 | if err != nil { 85 | return 86 | } 87 | m.Expires = time.Duration(expires) * time.Millisecond 88 | data = data[n:] 89 | 90 | if len(data) < 2 { 91 | return errLengthMismatch 92 | } 93 | m.GroupOrder = data[0] 94 | if m.GroupOrder > 2 { 95 | return errInvalidGroupOrder 96 | } 97 | if data[1] != 0 && data[1] != 1 { 98 | return errInvalidContentExistsByte 99 | } 100 | m.ContentExists = data[1] == 1 101 | data = data[2:] 102 | 103 | if !m.ContentExists { 104 | m.Parameters = KVPList{} 105 | return m.Parameters.parseNum(data) 106 | } 107 | 108 | n, err = m.LargestLocation.parse(v,
data) 109 | if err != nil { 110 | return err 111 | } 112 | data = data[n:] 113 | 114 | m.Parameters = KVPList{} 115 | return m.Parameters.parseNum(data) 116 | } 117 | -------------------------------------------------------------------------------- /internal/wire/key_value_pair_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestKeyValuePairAppend(t *testing.T) { 12 | cases := []struct { 13 | p KeyValuePair 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | p: KeyValuePair{ 19 | Type: 1, 20 | ValueBytes: []byte(""), 21 | }, 22 | buf: nil, 23 | expect: []byte{0x01, 0x00}, 24 | }, 25 | { 26 | p: KeyValuePair{ 27 | Type: 1, 28 | ValueBytes: []byte("A"), 29 | }, 30 | buf: nil, 31 | expect: []byte{0x01, 0x01, 'A'}, 32 | }, 33 | { 34 | p: KeyValuePair{ 35 | Type: 1, 36 | ValueBytes: []byte("A"), 37 | }, 38 | buf: []byte{0x01, 0x02}, 39 | expect: []byte{0x01, 0x02, 0x01, 0x01, 'A'}, 40 | }, 41 | { 42 | p: KeyValuePair{ 43 | Type: 2, 44 | ValueVarInt: uint64(1), 45 | }, 46 | buf: nil, 47 | expect: []byte{0x02, 0x01}, 48 | }, 49 | { 50 | p: KeyValuePair{ 51 | Type: MaxRequestIDParameterKey, 52 | ValueVarInt: uint64(2), 53 | }, 54 | buf: []byte{}, 55 | expect: []byte{0x02, 0x02}, 56 | }, 57 | { 58 | p: KeyValuePair{ 59 | Type: MaxRequestIDParameterKey, 60 | ValueVarInt: uint64(3), 61 | }, 62 | buf: []byte{0x01, 0x02}, 63 | expect: []byte{0x01, 0x02, 0x02, 0x03}, 64 | }, 65 | } 66 | for i, tc := range cases { 67 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 68 | res := tc.p.append(tc.buf) 69 | assert.Equal(t, tc.expect, res) 70 | }) 71 | } 72 | } 73 | 74 | func TestParseKeyValuePair(t *testing.T) { 75 | cases := []struct { 76 | data []byte 77 | expect KeyValuePair 78 | err error 79 | n int 80 | }{ 81 | { 82 | data: []byte{byte(MaxRequestIDParameterKey), 0x01}, 83 | expect: KeyValuePair{ 84 | 
Type: MaxRequestIDParameterKey, 85 | ValueVarInt: uint64(1), 86 | }, 87 | err: nil, 88 | n: 2, 89 | }, 90 | { 91 | data: append([]byte{byte(PathParameterKey), 11}, "/path/param"...), 92 | expect: KeyValuePair{ 93 | Type: 1, 94 | ValueBytes: []byte("/path/param"), 95 | }, 96 | err: nil, 97 | n: 13, 98 | }, 99 | { 100 | data: []byte{}, 101 | expect: KeyValuePair{}, 102 | err: io.EOF, 103 | n: 0, 104 | }, 105 | { 106 | data: []byte{0x05, 0x01, 0x00}, 107 | expect: KeyValuePair{ 108 | Type: 5, 109 | ValueBytes: []byte{0x00}, 110 | }, 111 | err: nil, 112 | n: 3, 113 | }, 114 | { 115 | data: []byte{0x01, 0x01, 'A'}, 116 | expect: KeyValuePair{ 117 | Type: PathParameterKey, 118 | ValueBytes: []byte("A"), 119 | }, 120 | err: nil, 121 | n: 3, 122 | }, 123 | } 124 | for i, tc := range cases { 125 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 126 | res := KeyValuePair{} 127 | n, err := res.parse(tc.data) 128 | assert.Equal(t, tc.expect, res) 129 | assert.Equal(t, tc.n, n) 130 | if tc.err != nil { 131 | assert.Error(t, err) 132 | assert.Equal(t, tc.err, err) 133 | } else { 134 | assert.NoError(t, err) 135 | } 136 | }) 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /internal/wire/object_datagram_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/quic-go/quic-go/quicvarint" 7 | ) 8 | 9 | const ( 10 | objectTypeDatagram uint64 = 0x00 11 | objectTypeDatagramExtension uint64 = 0x01 12 | objectTypeDatagramStatus uint64 = 0x02 13 | objectTypeDatagramStatusExtension uint64 = 0x03 14 | ) 15 | 16 | type ObjectDatagramMessage struct { 17 | TrackAlias uint64 18 | GroupID uint64 19 | ObjectID uint64 20 | PublisherPriority uint8 21 | ObjectExtensionHeaders KVPList 22 | ObjectPayload []byte 23 | ObjectStatus ObjectStatus 24 | } 25 | 26 | func (m *ObjectDatagramMessage) AppendDatagram(buf []byte) []byte { 27 | typ := 
objectTypeDatagram 28 | if m.ObjectExtensionHeaders != nil { 29 | typ = objectTypeDatagramExtension 30 | } 31 | buf = quicvarint.Append(buf, typ) 32 | buf = quicvarint.Append(buf, m.TrackAlias) 33 | buf = quicvarint.Append(buf, m.GroupID) 34 | buf = quicvarint.Append(buf, m.ObjectID) 35 | buf = append(buf, m.PublisherPriority) 36 | if typ == objectTypeDatagramExtension { 37 | buf = m.ObjectExtensionHeaders.appendLength(buf) 38 | } 39 | return append(buf, m.ObjectPayload...) 40 | } 41 | 42 | func (m *ObjectDatagramMessage) AppendDatagramStatus(buf []byte) []byte { 43 | typ := objectTypeDatagramStatus 44 | if m.ObjectExtensionHeaders != nil { 45 | typ = objectTypeDatagramStatusExtension 46 | } 47 | buf = quicvarint.Append(buf, typ) 48 | buf = quicvarint.Append(buf, m.TrackAlias) 49 | buf = quicvarint.Append(buf, m.GroupID) 50 | buf = quicvarint.Append(buf, m.ObjectID) 51 | buf = append(buf, m.PublisherPriority) 52 | if typ == objectTypeDatagramExtension { 53 | buf = m.ObjectExtensionHeaders.appendLength(buf) 54 | } 55 | return quicvarint.Append(buf, uint64(m.ObjectStatus)) 56 | } 57 | 58 | func (m *ObjectDatagramMessage) Parse(data []byte) (parsed int, err error) { 59 | var n int 60 | var typ uint64 61 | typ, n, err = quicvarint.Parse(data) 62 | parsed += n 63 | if err != nil { 64 | return parsed, err 65 | } 66 | data = data[n:] 67 | 68 | m.TrackAlias, n, err = quicvarint.Parse(data) 69 | parsed += n 70 | if err != nil { 71 | return 72 | } 73 | data = data[n:] 74 | 75 | m.GroupID, n, err = quicvarint.Parse(data) 76 | parsed += n 77 | if err != nil { 78 | return 79 | } 80 | data = data[n:] 81 | 82 | m.ObjectID, n, err = quicvarint.Parse(data) 83 | parsed += n 84 | if err != nil { 85 | return 86 | } 87 | data = data[n:] 88 | 89 | if len(data) == 0 { 90 | return parsed, io.ErrUnexpectedEOF 91 | } 92 | m.PublisherPriority = data[0] 93 | parsed += 1 94 | data = data[1:] 95 | 96 | if typ&0x01 == 1 { 97 | m.ObjectExtensionHeaders = KVPList{} 98 | n, err = 
m.ObjectExtensionHeaders.parseLength(data) 99 | parsed += n 100 | if err != nil { 101 | return parsed, err 102 | } 103 | } 104 | if typ&0x02 == 0 { 105 | m.ObjectPayload = make([]byte, len(data)) 106 | n = copy(m.ObjectPayload, data) 107 | parsed += n 108 | } else { 109 | var status uint64 110 | status, n, err = quicvarint.Parse(data) 111 | parsed += n 112 | m.ObjectStatus = ObjectStatus(status) 113 | } 114 | return 115 | } 116 | -------------------------------------------------------------------------------- /subscribe_response_writer.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import "time" 4 | 5 | type SubscribeResponseWriter struct { 6 | id uint64 7 | trackAlias uint64 8 | session *Session 9 | localTrack *localTrack 10 | handled bool 11 | } 12 | 13 | // SubscribeOKOption is a functional option for configuring SUBSCRIBE_OK responses. 14 | type SubscribeOKOption func(*SubscribeOkOptions) 15 | 16 | // WithExpires sets the subscription expiration duration. 17 | // A duration of 0 means the subscription never expires (default). 18 | func WithExpires(expires time.Duration) SubscribeOKOption { 19 | return func(opts *SubscribeOkOptions) { 20 | opts.Expires = expires 21 | } 22 | } 23 | 24 | // WithGroupOrder sets the group order for the subscription. 25 | // Default is GroupOrderAscending. 26 | func WithGroupOrder(groupOrder GroupOrder) SubscribeOKOption { 27 | return func(opts *SubscribeOkOptions) { 28 | opts.GroupOrder = groupOrder 29 | } 30 | } 31 | 32 | // WithLargestLocation sets the largest available location for the track. 33 | // When set, ContentExists is automatically set to true. 34 | // When nil, ContentExists is automatically set to false. 
35 | func WithLargestLocation(location *Location) SubscribeOKOption { 36 | return func(opts *SubscribeOkOptions) { 37 | opts.LargestLocation = location 38 | opts.ContentExists = location != nil 39 | } 40 | } 41 | 42 | // WithParameters sets additional key-value parameters for the response. 43 | func WithParameters(parameters KVPList) SubscribeOKOption { 44 | return func(opts *SubscribeOkOptions) { 45 | opts.Parameters = parameters 46 | } 47 | } 48 | 49 | // Accept accepts the subscription with the given options. 50 | // 51 | // Default behavior when no options are provided: 52 | // - Expires: 0 (never expires) 53 | // - GroupOrder: GroupOrderAscending 54 | // - ContentExists: false (no content available) 55 | // - LargestLocation: nil 56 | // - Parameters: empty 57 | // 58 | // Use WithLargestLocation to indicate content exists and provide the largest location. 59 | // ContentExists is automatically set based on whether LargestLocation is provided. 60 | func (w *SubscribeResponseWriter) Accept(options ...SubscribeOKOption) error { 61 | w.handled = true 62 | 63 | // Set default values 64 | opts := &SubscribeOkOptions{ 65 | Expires: 0, 66 | GroupOrder: GroupOrderAscending, 67 | ContentExists: false, 68 | LargestLocation: nil, 69 | Parameters: KVPList{}, 70 | } 71 | 72 | // Apply options 73 | for _, option := range options { 74 | option(opts) 75 | } 76 | 77 | if err := w.session.acceptSubscriptionWithOptions(w.id, opts); err != nil { 78 | return err 79 | } 80 | return nil 81 | } 82 | 83 | func (w *SubscribeResponseWriter) Reject(code ErrorCodeSubscribe, reason string) error { 84 | w.handled = true 85 | return w.session.rejectSubscription(w.id, code, reason) 86 | } 87 | 88 | func (w *SubscribeResponseWriter) SendDatagram(o Object) error { 89 | return w.localTrack.sendDatagram(o) 90 | } 91 | 92 | func (w *SubscribeResponseWriter) OpenSubgroup(groupID, subgroupID uint64, priority uint8) (*Subgroup, error) { 93 | return w.localTrack.openSubgroup(groupID, subgroupID, 
priority) 94 | } 95 | 96 | func (w *SubscribeResponseWriter) CloseWithError(code uint64, reason string) error { 97 | return w.localTrack.close(code, reason) 98 | } 99 | -------------------------------------------------------------------------------- /internal/wire/object_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | 7 | "github.com/quic-go/quic-go/quicvarint" 8 | ) 9 | 10 | type ObjectMessage struct { 11 | TrackAlias uint64 12 | GroupID uint64 13 | SubgroupID uint64 14 | ObjectID uint64 15 | PublisherPriority uint8 16 | ObjectExtensionHeaders KVPList 17 | ObjectStatus ObjectStatus 18 | ObjectPayload []byte 19 | } 20 | 21 | func (m *ObjectMessage) AppendSubgroup(buf []byte) []byte { 22 | buf = quicvarint.Append(buf, m.ObjectID) 23 | buf = m.ObjectExtensionHeaders.appendLength(buf) 24 | buf = quicvarint.Append(buf, uint64(len(m.ObjectPayload))) 25 | if len(m.ObjectPayload) == 0 { 26 | buf = quicvarint.Append(buf, uint64(m.ObjectStatus)) 27 | } else { 28 | buf = append(buf, m.ObjectPayload...) 29 | } 30 | return buf 31 | } 32 | 33 | func (m *ObjectMessage) AppendFetch(buf []byte) []byte { 34 | buf = quicvarint.Append(buf, m.GroupID) 35 | buf = quicvarint.Append(buf, m.SubgroupID) 36 | buf = quicvarint.Append(buf, m.ObjectID) 37 | buf = append(buf, m.PublisherPriority) 38 | buf = m.ObjectExtensionHeaders.appendLength(buf) 39 | buf = quicvarint.Append(buf, uint64(len(m.ObjectPayload))) 40 | if len(m.ObjectPayload) == 0 { 41 | buf = quicvarint.Append(buf, uint64(m.ObjectStatus)) 42 | } else { 43 | buf = append(buf, m.ObjectPayload...) 
44 | } 45 | return buf 46 | } 47 | 48 | func (m *ObjectMessage) readSubgroup(r io.Reader) (err error) { 49 | br := bufio.NewReader(r) // NOTE(review): a fresh bufio.Reader per message only preserves read-ahead across calls if r is already a *bufio.Reader — confirm callers pass one 50 | m.ObjectID, err = quicvarint.Read(br) 51 | if err != nil { 52 | return 53 | } 54 | 55 | if m.ObjectExtensionHeaders != nil { 56 | if err = m.ObjectExtensionHeaders.parseLengthReader(br); err != nil { 57 | return err 58 | } 59 | } 60 | 61 | length, err := quicvarint.Read(br) 62 | if err != nil { 63 | return 64 | } 65 | if length == 0 { 66 | var status uint64 67 | status, err = quicvarint.Read(br) 68 | if err != nil { 69 | return 70 | } 71 | m.ObjectStatus = ObjectStatus(status) 72 | return 73 | } 74 | m.ObjectPayload = make([]byte, length) 75 | _, err = io.ReadFull(br, m.ObjectPayload) // read via br: payload bytes may already sit in its buffer 76 | return 77 | } 78 | 79 | func (m *ObjectMessage) readFetch(r io.Reader) (err error) { 80 | br := bufio.NewReader(r) 81 | m.GroupID, err = quicvarint.Read(br) 82 | if err != nil { 83 | return 84 | } 85 | m.SubgroupID, err = quicvarint.Read(br) 86 | if err != nil { 87 | return 88 | } 89 | m.ObjectID, err = quicvarint.Read(br) 90 | if err != nil { 91 | return 92 | } 93 | m.PublisherPriority, err = br.ReadByte() 94 | if err != nil { 95 | return 96 | } 97 | if m.ObjectExtensionHeaders == nil { 98 | m.ObjectExtensionHeaders = KVPList{} 99 | } 100 | if err = m.ObjectExtensionHeaders.parseLengthReader(br); err != nil { 101 | return err 102 | } 103 | 104 | length, err := quicvarint.Read(br) 105 | if err != nil { 106 | return 107 | } 108 | 109 | if length == 0 { 110 | var status uint64 111 | status, err = quicvarint.Read(br) 112 | if err != nil { 113 | return 114 | } 115 | m.ObjectStatus = ObjectStatus(status) 116 | return 117 | } 118 | 119 | m.ObjectPayload = make([]byte, length) 120 | _, err = io.ReadFull(br, m.ObjectPayload) // read via br: payload bytes may already sit in its buffer 121 | return 122 | } 123 | -------------------------------------------------------------------------------- /remote_track_map.go: -------------------------------------------------------------------------------- 1 | package
moqtransport 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sync" 7 | ) 8 | 9 | type errRequestsBlocked struct { 10 | maxRequestID uint64 11 | } 12 | 13 | func (e errRequestsBlocked) Error() string { 14 | return fmt.Sprintf("too many subscribes, max_request_id=%v", e.maxRequestID) 15 | } 16 | 17 | var ( 18 | errUnknownRequestID = errors.New("unknown request ID") 19 | errDuplicateRequestIDBug = errors.New("internal error: duplicate request ID") 20 | errDuplicateTrackAliasBug = errors.New("internal error: duplicate track alias") 21 | ) 22 | 23 | type remoteTrackMap struct { 24 | lock sync.Mutex 25 | nextTrackAlias uint64 26 | pending map[uint64]*RemoteTrack 27 | open map[uint64]*RemoteTrack 28 | trackAliasToRequestID map[uint64]uint64 29 | } 30 | 31 | func newRemoteTrackMap() *remoteTrackMap { 32 | return &remoteTrackMap{ 33 | lock: sync.Mutex{}, 34 | nextTrackAlias: 0, 35 | pending: map[uint64]*RemoteTrack{}, 36 | open: map[uint64]*RemoteTrack{}, 37 | trackAliasToRequestID: map[uint64]uint64{}, 38 | } 39 | } 40 | 41 | func (m *remoteTrackMap) findByRequestID(id uint64) (*RemoteTrack, bool) { 42 | m.lock.Lock() 43 | defer m.lock.Unlock() 44 | sub, ok := m.open[id] 45 | if !ok { 46 | sub, ok = m.pending[id] 47 | } 48 | if !ok { 49 | return nil, false 50 | } 51 | return sub, true 52 | } 53 | 54 | func (m *remoteTrackMap) addPending(requestID uint64, rt *RemoteTrack) error { 55 | m.lock.Lock() 56 | defer m.lock.Unlock() 57 | if _, ok := m.pending[requestID]; ok { 58 | // Should never happen 59 | return errDuplicateRequestIDBug 60 | } 61 | if _, ok := m.open[requestID]; ok { 62 | // Should never happen 63 | return errDuplicateRequestIDBug 64 | } 65 | m.pending[requestID] = rt 66 | return nil 67 | } 68 | 69 | func (m *remoteTrackMap) addPendingWithAlias(requestID uint64, rt *RemoteTrack) error { 70 | m.lock.Lock() 71 | defer m.lock.Unlock() 72 | if _, ok := m.pending[requestID]; ok { 73 | return errDuplicateRequestIDBug 74 | } 75 | if _, ok := m.open[requestID]; ok { 76 | 
return errDuplicateRequestIDBug 77 | } 78 | m.pending[requestID] = rt 79 | return nil 80 | } 81 | 82 | func (m *remoteTrackMap) setAlias(id, alias uint64) error { 83 | m.lock.Lock() 84 | defer m.lock.Unlock() 85 | if _, ok := m.trackAliasToRequestID[alias]; ok { 86 | return errDuplicateTrackAliasBug 87 | } 88 | m.trackAliasToRequestID[alias] = id 89 | return nil 90 | } 91 | 92 | func (m *remoteTrackMap) confirm(id uint64) (*RemoteTrack, error) { 93 | m.lock.Lock() 94 | defer m.lock.Unlock() 95 | s, ok := m.pending[id] 96 | if !ok { 97 | return nil, errUnknownRequestID 98 | } 99 | delete(m.pending, id) 100 | m.open[id] = s 101 | return s, nil 102 | } 103 | 104 | func (m *remoteTrackMap) reject(id uint64) (*RemoteTrack, bool) { 105 | m.lock.Lock() 106 | defer m.lock.Unlock() 107 | s, ok := m.pending[id] 108 | if !ok { 109 | return nil, false 110 | } 111 | delete(m.pending, id) 112 | return s, true 113 | } 114 | 115 | func (m *remoteTrackMap) findByTrackAlias(alias uint64) (*RemoteTrack, bool) { 116 | m.lock.Lock() 117 | id, ok := m.trackAliasToRequestID[alias] 118 | m.lock.Unlock() 119 | if !ok { 120 | return nil, false 121 | } 122 | return m.findByRequestID(id) 123 | } 124 | -------------------------------------------------------------------------------- /internal/wire/subscribe_update_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestSubscribeUpdateMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | sum SubscribeUpdateMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | sum: SubscribeUpdateMessage{ 19 | RequestID: 0, 20 | StartLocation: Location{ 21 | Group: 0, 22 | Object: 0, 23 | }, 24 | EndGroup: 0, 25 | SubscriberPriority: 0, 26 | Forward: 0, 27 | Parameters: KVPList{}, 28 | }, 29 | buf: []byte{}, 30 | expect: []byte{ 31 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 32 | }, 33 | }, 34 | { 35 | sum: SubscribeUpdateMessage{ 36 | RequestID: 1, 37 | StartLocation: Location{ 38 | Group: 2, 39 | Object: 3, 40 | }, 41 | EndGroup: 4, 42 | SubscriberPriority: 5, 43 | Forward: 1, 44 | Parameters: KVPList{KeyValuePair{Type: PathParameterKey, ValueBytes: []byte("A")}}, 45 | }, 46 | buf: []byte{}, 47 | expect: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x01, 0x01, 0x01, 0x01, 'A'}, 48 | }, 49 | { 50 | sum: SubscribeUpdateMessage{ 51 | RequestID: 1, 52 | StartLocation: Location{ 53 | Group: 2, 54 | Object: 3, 55 | }, 56 | EndGroup: 4, 57 | SubscriberPriority: 5, 58 | Forward: 1, 59 | Parameters: KVPList{KeyValuePair{Type: PathParameterKey, ValueBytes: []byte("A")}}, 60 | }, 61 | buf: []byte{0x0a, 0x0b}, 62 | expect: []byte{0x0a, 0x0b, 0x01, 0x02, 0x03, 0x04, 0x05, 0x01, 0x01, 0x01, 0x01, 'A'}, 63 | }, 64 | } 65 | for i, tc := range cases { 66 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 67 | res := tc.sum.Append(tc.buf) 68 | assert.Equal(t, tc.expect, res) 69 | }) 70 | } 71 | } 72 | 73 | func TestParseSubscribeUpdateMessage(t *testing.T) { 74 | cases := []struct { 75 | data []byte 76 | expect *SubscribeUpdateMessage 77 | err error 78 | }{ 79 | { 80 | data: nil, 81 | expect: &SubscribeUpdateMessage{}, 82 | err: io.EOF, 83 | }, 84 | { 85 | data: []byte{}, 86 | expect: &SubscribeUpdateMessage{}, 87 | err: io.EOF, 88 | }, 89 | { 90 | data: []byte{0x00, 0x01, 0x02}, 91 | expect: &SubscribeUpdateMessage{ 92 | RequestID: 0, 93 | StartLocation: Location{ 94 | Group: 1, 95 | Object: 2, 96 | }, 97 | EndGroup: 0, 98 | SubscriberPriority: 0, 99 | Forward: 0, 100 | Parameters: nil, 101 | }, 102 | err: io.EOF, 103 | }, 104 | { 105 | data: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x01, 0x01, 0x01, 0x01, 'P'}, 106 | expect: &SubscribeUpdateMessage{ 107 | RequestID: 1, 108 | StartLocation: Location{ 109 | Group: 2, 110 | Object: 3, 111 | }, 112 | EndGroup: 4, 113 | SubscriberPriority: 5, 114 | Forward: 1, 115 | Parameters: KVPList{KeyValuePair{Type: 
PathParameterKey, ValueBytes: []byte("P")}}, 116 | }, 117 | err: nil, 118 | }, 119 | } 120 | for i, tc := range cases { 121 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 122 | res := &SubscribeUpdateMessage{} 123 | err := res.parse(CurrentVersion, tc.data) 124 | assert.Equal(t, tc.expect, res) 125 | if tc.err != nil { 126 | assert.Equal(t, tc.err, err) 127 | } else { 128 | assert.NoError(t, err) 129 | } 130 | }) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /internal/wire/server_setup_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestServerSetupMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | ssm ServerSetupMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | ssm: ServerSetupMessage{ 19 | SelectedVersion: 0, 20 | SetupParameters: nil, 21 | }, 22 | buf: []byte{}, 23 | expect: []byte{ 24 | 0x00, 0x00, 25 | }, 26 | }, 27 | { 28 | ssm: ServerSetupMessage{ 29 | SelectedVersion: 0, 30 | SetupParameters: KVPList{}, 31 | }, 32 | buf: []byte{}, 33 | expect: []byte{ 34 | 0x00, 0x00, 35 | }, 36 | }, 37 | { 38 | ssm: ServerSetupMessage{ 39 | SelectedVersion: 0, 40 | SetupParameters: KVPList{ 41 | KeyValuePair{ 42 | Type: MaxRequestIDParameterKey, 43 | ValueVarInt: 2, 44 | }, 45 | }, 46 | }, 47 | buf: []byte{}, 48 | expect: []byte{ 49 | 0x00, 0x01, 0x02, 0x02, 50 | }, 51 | }, 52 | { 53 | ssm: ServerSetupMessage{ 54 | SelectedVersion: 0, 55 | SetupParameters: KVPList{KeyValuePair{ 56 | Type: PathParameterKey, 57 | ValueBytes: []byte("A"), 58 | }}, 59 | }, 60 | buf: []byte{0x01, 0x02}, 61 | expect: []byte{0x01, 0x02, 62 | 0x00, 0x01, 0x01, 0x01, 'A', 63 | }, 64 | }, 65 | } 66 | for i, tc := range cases { 67 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 68 | res := tc.ssm.Append(tc.buf) 69 | assert.Equal(t, 
tc.expect, res) 70 | }) 71 | } 72 | } 73 | 74 | func TestParseServerSetupMessage(t *testing.T) { 75 | cases := []struct { 76 | data []byte 77 | expect *ServerSetupMessage 78 | err error 79 | }{ 80 | { 81 | data: nil, 82 | expect: &ServerSetupMessage{}, 83 | err: io.EOF, 84 | }, 85 | { 86 | data: []byte{}, 87 | expect: &ServerSetupMessage{}, 88 | err: io.EOF, 89 | }, 90 | { 91 | data: []byte{ 92 | 0x00, 0x01, 93 | }, 94 | expect: &ServerSetupMessage{ 95 | SelectedVersion: 0, 96 | SetupParameters: KVPList{}, 97 | }, 98 | err: io.EOF, 99 | }, 100 | { 101 | data: []byte{ 102 | 0xc0, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 103 | }, 104 | expect: &ServerSetupMessage{ 105 | SelectedVersion: Draft_ietf_moq_transport_00, 106 | SetupParameters: KVPList{}, 107 | }, 108 | err: nil, 109 | }, 110 | { 111 | data: []byte{ 112 | 0x00, 0x01, 0x01, 0x01, 'A', 113 | }, 114 | expect: &ServerSetupMessage{ 115 | SelectedVersion: 0, 116 | SetupParameters: KVPList{KeyValuePair{ 117 | Type: PathParameterKey, 118 | ValueBytes: []byte("A"), 119 | }}, 120 | }, 121 | err: nil, 122 | }, 123 | { 124 | data: []byte{ 125 | 0x00, 0x01, 0x01, 0x01, 'A', 0x0a, 0x0b, 0x0c, 0x0d, 126 | }, 127 | expect: &ServerSetupMessage{ 128 | SelectedVersion: 0, 129 | SetupParameters: KVPList{KeyValuePair{ 130 | Type: PathParameterKey, 131 | ValueBytes: []byte("A"), 132 | }}, 133 | }, 134 | err: nil, 135 | }, 136 | } 137 | for i, tc := range cases { 138 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 139 | res := &ServerSetupMessage{} 140 | err := res.parse(CurrentVersion, tc.data) 141 | if tc.err != nil { 142 | assert.Equal(t, tc.err, err) 143 | assert.Equal(t, tc.expect, res) 144 | return 145 | } 146 | assert.NoError(t, err) 147 | assert.Equal(t, tc.expect, res) 148 | }) 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /integrationtests/utils.go: -------------------------------------------------------------------------------- 1 | package 
integrationtests 2 | 3 | import ( 4 | "context" 5 | "crypto/rand" 6 | "crypto/rsa" 7 | "crypto/tls" 8 | "crypto/x509" 9 | "encoding/pem" 10 | "fmt" 11 | "math/big" 12 | "net" 13 | "sync" 14 | "testing" 15 | 16 | "github.com/mengelbart/moqtransport" 17 | "github.com/mengelbart/moqtransport/quicmoq" 18 | "github.com/quic-go/quic-go" 19 | "github.com/stretchr/testify/assert" 20 | ) 21 | 22 | func connect(t *testing.T) (server, client *quic.Conn, cancel func()) { 23 | tlsConfig, err := generateTLSConfig() 24 | assert.NoError(t, err) 25 | listener, err := quic.ListenAddr("localhost:0", tlsConfig, &quic.Config{ 26 | EnableDatagrams: true, 27 | }) 28 | assert.NoError(t, err) 29 | 30 | clientConn, err := quic.DialAddr(context.Background(), fmt.Sprintf("localhost:%d", listener.Addr().(*net.UDPAddr).Port), &tls.Config{ 31 | InsecureSkipVerify: true, 32 | NextProtos: []string{"moq-00"}, 33 | }, &quic.Config{ 34 | EnableDatagrams: true, 35 | }) 36 | assert.NoError(t, err) 37 | 38 | serverConn, err := listener.Accept(context.Background()) 39 | assert.NoError(t, err) 40 | 41 | return serverConn, clientConn, func() { 42 | listener.Close() 43 | assert.NoError(t, clientConn.CloseWithError(0, "")) 44 | assert.NoError(t, serverConn.CloseWithError(0, "")) 45 | } 46 | } 47 | 48 | func setup(t *testing.T, sConn, cConn *quic.Conn, handler moqtransport.Handler) ( 49 | serverSession *moqtransport.Session, 50 | clientSession *moqtransport.Session, 51 | cancel func(), 52 | ) { 53 | return setupWithHandlers(t, sConn, cConn, handler, nil) 54 | } 55 | 56 | func setupWithHandlers(t *testing.T, sConn, cConn *quic.Conn, handler moqtransport.Handler, subscribeHandler moqtransport.SubscribeHandler) ( 57 | serverSession *moqtransport.Session, 58 | clientSession *moqtransport.Session, 59 | cancel func(), 60 | ) { 61 | serverSession = &moqtransport.Session{ 62 | Handler: handler, 63 | SubscribeHandler: subscribeHandler, 64 | InitialMaxRequestID: 100, 65 | Qlogger: nil, 66 | } 67 | var wg sync.WaitGroup 
68 | wg.Add(1) 69 | go func() { 70 | defer wg.Done() 71 | err := serverSession.Run(quicmoq.NewServer(sConn)) 72 | assert.NoError(t, err) 73 | }() 74 | 75 | clientSession = &moqtransport.Session{ 76 | Handler: handler, 77 | SubscribeHandler: subscribeHandler, 78 | InitialMaxRequestID: 100, 79 | Qlogger: nil, 80 | } 81 | 82 | wg.Add(1) 83 | go func() { 84 | defer wg.Done() 85 | err := clientSession.Run(quicmoq.NewClient(cConn)) 86 | assert.NoError(t, err) 87 | }() 88 | 89 | cancel = func() { 90 | serverSession.Close() 91 | clientSession.Close() 92 | } 93 | wg.Wait() 94 | return 95 | } 96 | 97 | // Setup a bare-bones TLS config for the server 98 | func generateTLSConfig() (*tls.Config, error) { 99 | key, err := rsa.GenerateKey(rand.Reader, 1024) 100 | if err != nil { 101 | return nil, err 102 | } 103 | template := x509.Certificate{SerialNumber: big.NewInt(1)} 104 | certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) 105 | if err != nil { 106 | return nil, err 107 | } 108 | keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) 109 | certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) 110 | 111 | tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) 112 | if err != nil { 113 | return nil, err 114 | } 115 | return &tls.Config{ 116 | Certificates: []tls.Certificate{tlsCert}, 117 | NextProtos: []string{"moq-00", "h3"}, 118 | }, nil 119 | } 120 | -------------------------------------------------------------------------------- /internal/wire/client_setup_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestClientSetupMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | csm ClientSetupMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | csm: ClientSetupMessage{ 19 
| SupportedVersions: nil, 20 | SetupParameters: nil, 21 | }, 22 | buf: []byte{}, 23 | expect: []byte{ 24 | 0x00, 0x00, 25 | }, 26 | }, 27 | { 28 | csm: ClientSetupMessage{ 29 | SupportedVersions: []Version{Draft_ietf_moq_transport_00}, 30 | SetupParameters: KVPList{}, 31 | }, 32 | buf: []byte{}, 33 | expect: []byte{ 34 | 0x01, 0xc0, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 35 | }, 36 | }, 37 | { 38 | csm: ClientSetupMessage{ 39 | SupportedVersions: []Version{Draft_ietf_moq_transport_00}, 40 | SetupParameters: KVPList{ 41 | KeyValuePair{ 42 | Type: PathParameterKey, 43 | ValueBytes: []byte("A"), 44 | }, 45 | }, 46 | }, 47 | buf: []byte{}, 48 | expect: []byte{ 49 | 0x01, 0xc0, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 'A', 50 | }, 51 | }, 52 | } 53 | for i, tc := range cases { 54 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 55 | res := tc.csm.Append(tc.buf) 56 | assert.Equal(t, tc.expect, res) 57 | }) 58 | } 59 | } 60 | 61 | func TestParseClientSetupMessage(t *testing.T) { 62 | cases := []struct { 63 | data []byte 64 | expect *ClientSetupMessage 65 | err error 66 | }{ 67 | { 68 | data: nil, 69 | expect: &ClientSetupMessage{}, 70 | err: io.EOF, 71 | }, 72 | { 73 | data: []byte{}, 74 | expect: &ClientSetupMessage{}, 75 | err: io.EOF, 76 | }, 77 | { 78 | data: []byte{ 79 | 0x01, 0x00, 80 | }, 81 | expect: &ClientSetupMessage{ 82 | SupportedVersions: []Version{0x00}, 83 | SetupParameters: KVPList{}, 84 | }, 85 | err: io.EOF, 86 | }, 87 | { 88 | data: []byte{ 89 | 0x01, 90 | }, 91 | expect: &ClientSetupMessage{}, 92 | err: io.EOF, 93 | }, 94 | { 95 | data: []byte{ 96 | 0x02, 0x00, 0x00 + 1, 97 | }, 98 | expect: &ClientSetupMessage{ 99 | SupportedVersions: []Version{0x00, 0x01}, 100 | SetupParameters: KVPList{}, 101 | }, 102 | err: io.EOF, 103 | }, 104 | { 105 | data: []byte{ 106 | 0x02, 0x00, 0x00 + 1, 0x00, 107 | }, 108 | expect: &ClientSetupMessage{ 109 | SupportedVersions: []Version{0, 0 + 1}, 110 | SetupParameters: KVPList{}, 111 | }, 
112 | err: nil, 113 | }, 114 | { 115 | data: []byte{ 116 | 0x01, 0x00, 0x00, 117 | }, 118 | expect: &ClientSetupMessage{ 119 | SupportedVersions: []Version{0}, 120 | SetupParameters: KVPList{}, 121 | }, 122 | err: nil, 123 | }, 124 | { 125 | data: []byte{ 126 | 0x01, 0x00, 127 | }, 128 | expect: &ClientSetupMessage{ 129 | SupportedVersions: []Version{0x00}, 130 | SetupParameters: KVPList{}, 131 | }, 132 | err: io.EOF, 133 | }, 134 | { 135 | data: []byte{ 136 | 0x01, 0x00, 137 | 0x02, 0x02, 0x02, 0x01, 0x01, 'a', 138 | }, 139 | expect: &ClientSetupMessage{ 140 | SupportedVersions: []Version{0x00}, 141 | SetupParameters: KVPList{ 142 | KeyValuePair{ 143 | Type: MaxRequestIDParameterKey, 144 | ValueVarInt: 2, 145 | }, 146 | KeyValuePair{ 147 | Type: 1, 148 | ValueBytes: []byte("a"), 149 | }, 150 | }, 151 | }, 152 | err: nil, 153 | }, 154 | } 155 | for i, tc := range cases { 156 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 157 | res := &ClientSetupMessage{} 158 | err := res.parse(CurrentVersion, tc.data) 159 | assert.Equal(t, tc.expect, res) 160 | if tc.err != nil { 161 | assert.Equal(t, tc.err, err) 162 | } else { 163 | assert.NoError(t, err) 164 | } 165 | }) 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk= 5 | github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= 6 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 7 | github.com/google/go-cmp v0.6.0/go.mod 
h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 8 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 9 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 10 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 11 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 12 | github.com/mengelbart/qlog v0.1.0 h1:8cDMuCMcKtzkPXUU5FF7OBwqKiy+De0GKvIvNawifoA= 13 | github.com/mengelbart/qlog v0.1.0/go.mod h1:nIlGcUugkfDu41B8LKdAwjHQ1NxAF54D9hS2EDOlVyk= 14 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 15 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 16 | github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= 17 | github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= 18 | github.com/quic-go/quic-go v0.53.0 h1:QHX46sISpG2S03dPeZBgVIZp8dGagIaiu2FiVYvpCZI= 19 | github.com/quic-go/quic-go v0.53.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= 20 | github.com/quic-go/webtransport-go v0.9.0 h1:jgys+7/wm6JarGDrW+lD/r9BGqBAmqY/ssklE09bA70= 21 | github.com/quic-go/webtransport-go v0.9.0/go.mod h1:4FUYIiUc75XSsF6HShcLeXXYZJ9AGwo/xh3L8M/P1ao= 22 | github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= 23 | github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= 24 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 25 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 26 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 27 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 28 | go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= 29 | go.uber.org/mock v0.5.0/go.mod 
h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= 30 | golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= 31 | golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= 32 | golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= 33 | golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= 34 | golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= 35 | golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= 36 | golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= 37 | golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 38 | golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= 39 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 40 | golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= 41 | golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= 42 | golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE= 43 | golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588= 44 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 45 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 46 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 47 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 48 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 49 | -------------------------------------------------------------------------------- /connection.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | ) 8 | 9 | // Protocol is a transport 
protocol supported by MoQ. 10 | type Protocol int 11 | 12 | // The supported protocols 13 | const ( 14 | ProtocolQUIC Protocol = iota 15 | ProtocolWebTransport 16 | ) 17 | 18 | func (p Protocol) String() string { 19 | switch p { 20 | case ProtocolQUIC: 21 | return "quic" 22 | case ProtocolWebTransport: 23 | return "webtransport" 24 | default: 25 | return "invalid protocol" 26 | } 27 | } 28 | 29 | // Perspective indicates whether the connection is a client or a server 30 | type Perspective int 31 | 32 | // The perspectives 33 | const ( 34 | PerspectiveClient Perspective = iota 35 | PerspectiveServer 36 | ) 37 | 38 | func (p Perspective) String() string { 39 | switch p { 40 | case PerspectiveServer: 41 | return "server" 42 | case PerspectiveClient: 43 | return "client" 44 | default: 45 | return "invalid perspective" 46 | } 47 | } 48 | 49 | // A Stream is the interface implemented by bidirectional streams. 50 | type Stream interface { 51 | ReceiveStream 52 | SendStream 53 | } 54 | 55 | // ReceiveStream is the interface implemented by the receiving end of unidirectional 56 | // streams. 57 | type ReceiveStream interface { 58 | // Read reads from the stream. 59 | io.Reader 60 | 61 | // Stop stops reading from the stream and sends a signal to the sender to 62 | // stop sending on the stream. 63 | Stop(uint32) 64 | 65 | // StreamID returns the ID of the stream 66 | StreamID() uint64 67 | } 68 | 69 | // SendStream is the interface implemented by the sending end of unidirectional 70 | // streams. 71 | type SendStream interface { 72 | // Write writes to the stream. 73 | // Close closes the stream and guarantees retransmissions until all data has 74 | // been received by the receiver or the stream is reset. 75 | io.WriteCloser 76 | 77 | // Reset closes the stream and stops retransmitting outstanding data. 
78 | Reset(uint32) 79 | 80 | // StreamID returns the ID of the stream 81 | StreamID() uint64 82 | } 83 | 84 | var ErrDatagramSupportDisabled = errors.New("datagram support disabled") 85 | 86 | // Connection is the interface of a QUIC/WebTransport connection. New Transports 87 | // expect an implementation of this interface as the underlying connection. 88 | // Implementations based on quic-go and webtransport-go are provided in quicmoq 89 | // and webtransportmoq. 90 | type Connection interface { 91 | // AcceptStream returns the next stream opened by the peer, blocking until 92 | // one is available. 93 | AcceptStream(context.Context) (Stream, error) 94 | 95 | // AcceptUniStream returns the next unidirectional stream opened by the 96 | // peer, blocking until one is available. 97 | AcceptUniStream(context.Context) (ReceiveStream, error) 98 | 99 | // OpenStream opens a new bidirectional stream. 100 | OpenStream() (Stream, error) 101 | 102 | // OpenStreamSync opens a new bidirectional stream, blocking until it can be 103 | // opened. 104 | OpenStreamSync(context.Context) (Stream, error) 105 | 106 | // OpenUniStream opens a new unidirectional stream. 107 | OpenUniStream() (SendStream, error) 108 | 109 | // OpenUniStreamSync opens a new unidirectional stream, blocking until it can be 110 | // opened. 111 | OpenUniStreamSync(context.Context) (SendStream, error) 112 | 113 | // SendDatagram sends a datagram. 114 | SendDatagram([]byte) error 115 | 116 | // ReceiveDatagram receives the next datagram, blocking until one is 117 | // available. 118 | ReceiveDatagram(context.Context) ([]byte, error) 119 | 120 | // CloseWithError closes the connection with an error code and a reason 121 | // string. 122 | CloseWithError(uint64, string) error 123 | 124 | // Context returns a context that will be cancelled when the connection is 125 | // closed. 126 | Context() context.Context 127 | 128 | // Protocol returns the underlying Protocol of the connection.
129 | Protocol() Protocol 130 | 131 | // Perspective returns the perspective of the connection. 132 | Perspective() Perspective 133 | } 134 | -------------------------------------------------------------------------------- /internal/wire/control_message_type.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "io" 5 | "log/slog" 6 | ) 7 | 8 | type controlMessageType uint64 9 | 10 | // Control message types 11 | const ( 12 | messageTypeClientSetup controlMessageType = 0x20 13 | messageTypeServerSetup controlMessageType = 0x21 14 | 15 | messageTypeGoAway controlMessageType = 0x10 16 | 17 | messageTypeMaxRequestID controlMessageType = 0x15 18 | messageTypeRequestsBlocked controlMessageType = 0x1a 19 | 20 | messageTypeSubscribe controlMessageType = 0x03 21 | messageTypeSubscribeOk controlMessageType = 0x04 22 | messageTypeSubscribeError controlMessageType = 0x05 23 | messageTypeSubscribeUpdate controlMessageType = 0x02 24 | messageTypeUnsubscribe controlMessageType = 0x0a 25 | messageTypeSubscribeDone controlMessageType = 0x0b 26 | 27 | messageTypePublish controlMessageType = 0x1d 28 | messageTypePublishOk controlMessageType = 0x1e 29 | messageTypePublishError controlMessageType = 0x1f 30 | 31 | messageTypeFetch controlMessageType = 0x16 32 | messageTypeFetchOk controlMessageType = 0x18 33 | messageTypeFetchError controlMessageType = 0x19 34 | messageTypeFetchCancel controlMessageType = 0x17 35 | 36 | messageTypeTrackStatus controlMessageType = 0x0d 37 | messageTypeTrackStatusOk controlMessageType = 0x0e 38 | messageTypeTrackStatusError controlMessageType = 0x0f 39 | 40 | messageTypeAnnounce controlMessageType = 0x06 41 | messageTypeAnnounceOk controlMessageType = 0x07 42 | messageTypeAnnounceError controlMessageType = 0x08 43 | messageTypeUnannounce controlMessageType = 0x09 44 | messageTypeAnnounceCancel controlMessageType = 0x0c 45 | 46 | messageTypeSubscribeNamespace controlMessageType = 0x11 47 
| messageTypeSubscribeNamespaceOk controlMessageType = 0x12 48 | messageTypeSubscribeNamespaceError controlMessageType = 0x13 49 | messageTypeUnsubscribeNamespace controlMessageType = 0x14 50 | ) 51 | 52 | func (mt controlMessageType) String() string { 53 | switch mt { 54 | case messageTypeClientSetup: 55 | return "ClientSetup" 56 | case messageTypeServerSetup: 57 | return "ServerSetup" 58 | 59 | case messageTypeGoAway: 60 | return "GoAway" 61 | 62 | case messageTypeMaxRequestID: 63 | return "MaxRequestID" 64 | case messageTypeRequestsBlocked: 65 | return "RequestsBlocked" 66 | 67 | case messageTypeSubscribe: 68 | return "Subscribe" 69 | case messageTypeSubscribeOk: 70 | return "SubscribeOk" 71 | case messageTypeSubscribeError: 72 | return "SubscribeError" 73 | case messageTypeUnsubscribe: 74 | return "Unsubscribe" 75 | case messageTypeSubscribeUpdate: 76 | return "SubscribeUpdate" 77 | case messageTypeSubscribeDone: 78 | return "SubscribeDone" 79 | 80 | case messageTypePublish: 81 | return "Publish" 82 | case messageTypePublishOk: 83 | return "PublishOk" 84 | case messageTypePublishError: 85 | return "PublishError" 86 | 87 | case messageTypeFetch: 88 | return "Fetch" 89 | case messageTypeFetchOk: 90 | return "FetchOk" 91 | case messageTypeFetchError: 92 | return "FetchError" 93 | case messageTypeFetchCancel: 94 | return "FetchCancel" 95 | 96 | case messageTypeTrackStatus: 97 | return "TrackStatus" 98 | case messageTypeTrackStatusOk: 99 | return "TrackStatusOk" 100 | case messageTypeTrackStatusError: 101 | return "TrackStatusError" 102 | 103 | case messageTypeAnnounce: 104 | return "Announce" 105 | case messageTypeAnnounceOk: 106 | return "AnnounceOk" 107 | case messageTypeAnnounceError: 108 | return "AnnounceError" 109 | case messageTypeUnannounce: 110 | return "Unannounce" 111 | case messageTypeAnnounceCancel: 112 | return "AnnounceCancel" 113 | 114 | case messageTypeSubscribeNamespace: 115 | return "SubscribeNamespace" 116 | case messageTypeSubscribeNamespaceOk: 
117 | return "SubscribeNamespaceOk" 118 | case messageTypeSubscribeNamespaceError: 119 | return "SubscribeNamespaceError" 120 | case messageTypeUnsubscribeNamespace: 121 | return "UnsubscribeNamespace" 122 | } 123 | return "unknown message type" 124 | } 125 | 126 | type messageReader interface { 127 | io.Reader 128 | io.ByteReader 129 | Discard(int) (int, error) 130 | } 131 | 132 | type Message interface { 133 | Append([]byte) []byte 134 | parse(Version, []byte) error 135 | } 136 | 137 | type ControlMessage interface { 138 | Message 139 | Type() controlMessageType 140 | slog.LogValuer 141 | } 142 | -------------------------------------------------------------------------------- /internal/wire/subscribe_done_message_test.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestSubscribeDoneMessageAppend(t *testing.T) { 12 | cases := []struct { 13 | srm SubscribeDoneMessage 14 | buf []byte 15 | expect []byte 16 | }{ 17 | { 18 | srm: SubscribeDoneMessage{ 19 | RequestID: 0, 20 | StatusCode: 0, 21 | StreamCount: 0, 22 | ReasonPhrase: "", 23 | }, 24 | buf: []byte{}, 25 | expect: []byte{0x00, 0x00, 0x00, 0x00}, 26 | }, 27 | { 28 | srm: SubscribeDoneMessage{ 29 | RequestID: 0, 30 | StatusCode: 1, 31 | StreamCount: 2, 32 | ReasonPhrase: "reason", 33 | }, 34 | buf: []byte{}, 35 | expect: []byte{ 36 | 0x00, 37 | 0x01, 38 | 0x02, 39 | 0x06, 'r', 'e', 'a', 's', 'o', 'n', 40 | }, 41 | }, 42 | { 43 | srm: SubscribeDoneMessage{ 44 | RequestID: 17, 45 | StatusCode: 1, 46 | StreamCount: 4, 47 | ReasonPhrase: "reason", 48 | }, 49 | buf: []byte{0x0a, 0x0b, 0x0c, 0x0d}, 50 | expect: []byte{ 51 | 0x0a, 0x0b, 0x0c, 0x0d, 52 | 0x11, 53 | 0x01, 54 | 0x04, 55 | 0x06, 'r', 'e', 'a', 's', 'o', 'n', 56 | }, 57 | }, 58 | { 59 | srm: SubscribeDoneMessage{ 60 | RequestID: 0, 61 | StatusCode: 0, 62 | StreamCount: 0, 63 | ReasonPhrase: 
"", 64 | }, 65 | buf: []byte{}, 66 | expect: []byte{0x00, 0x00, 0x00, 0x00}, 67 | }, 68 | { 69 | srm: SubscribeDoneMessage{ 70 | RequestID: 0, 71 | StatusCode: 1, 72 | StreamCount: 2, 73 | ReasonPhrase: "reason", 74 | }, 75 | buf: []byte{}, 76 | expect: []byte{ 77 | 0x00, 78 | 0x01, 79 | 0x02, 80 | 0x06, 'r', 'e', 'a', 's', 'o', 'n', 81 | }, 82 | }, 83 | { 84 | srm: SubscribeDoneMessage{ 85 | RequestID: 17, 86 | StatusCode: 1, 87 | StreamCount: 2, 88 | ReasonPhrase: "reason", 89 | }, 90 | buf: []byte{0x0a, 0x0b, 0x0c, 0x0d}, 91 | expect: []byte{ 92 | 0x0a, 0x0b, 0x0c, 0x0d, 93 | 0x11, 94 | 0x01, 95 | 0x02, 96 | 0x06, 'r', 'e', 'a', 's', 'o', 'n', 97 | }, 98 | }, 99 | } 100 | for i, tc := range cases { 101 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 102 | res := tc.srm.Append(tc.buf) 103 | assert.Equal(t, tc.expect, res) 104 | }) 105 | } 106 | } 107 | 108 | func TestParseSubscribeDoneMessage(t *testing.T) { 109 | cases := []struct { 110 | data []byte 111 | expect *SubscribeDoneMessage 112 | err error 113 | }{ 114 | { 115 | data: nil, 116 | expect: &SubscribeDoneMessage{}, 117 | err: io.EOF, 118 | }, 119 | { 120 | data: []byte{}, 121 | expect: &SubscribeDoneMessage{}, 122 | err: io.EOF, 123 | }, 124 | { 125 | data: []byte{ 126 | 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 127 | }, 128 | expect: &SubscribeDoneMessage{ 129 | RequestID: 0, 130 | StatusCode: 0, 131 | StreamCount: 0, 132 | ReasonPhrase: "", 133 | }, 134 | err: nil, 135 | }, 136 | { 137 | data: []byte{ 138 | 0x00, 139 | 0x01, 140 | 0x02, 141 | 0x06, 'r', 'e', 'a', 's', 'o', 'n', 142 | 0x01, 143 | 0x02, 144 | 0x03, 145 | }, 146 | expect: &SubscribeDoneMessage{ 147 | RequestID: 0, 148 | StatusCode: 1, 149 | StreamCount: 2, 150 | ReasonPhrase: "reason", 151 | }, 152 | err: nil, 153 | }, 154 | { 155 | data: []byte{ 156 | 0x00, 157 | 0x01, 158 | 0x02, 159 | 0x06, 'r', 'e', 'a', 's', 'o', 'n', 160 | 0x00, 161 | }, 162 | expect: &SubscribeDoneMessage{ 163 | RequestID: 0, 164 | StatusCode: 1, 165 | 
StreamCount: 2, 166 | ReasonPhrase: "reason", 167 | }, 168 | err: nil, 169 | }, 170 | { 171 | data: []byte{ 172 | 0x00, 0x00, 0x00, 0x00, 173 | }, 174 | expect: &SubscribeDoneMessage{ 175 | RequestID: 0, 176 | StatusCode: 0, 177 | StreamCount: 0, 178 | ReasonPhrase: "", 179 | }, 180 | err: nil, 181 | }, 182 | } 183 | for i, tc := range cases { 184 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { 185 | res := &SubscribeDoneMessage{} 186 | err := res.parse(CurrentVersion, tc.data) 187 | assert.Equal(t, tc.expect, res) 188 | if tc.err != nil { 189 | assert.Equal(t, tc.err, err) 190 | } else { 191 | assert.NoError(t, err) 192 | } 193 | }) 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /mock_control_message_stream_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: github.com/mengelbart/moqtransport (interfaces: ControlMessageStream) 3 | // 4 | // Generated by this command: 5 | // 6 | // mockgen -build_flags=-tags=gomock -typed -package moqtransport -write_package_comment=false -self_package github.com/mengelbart/moqtransport -destination mock_control_message_stream_test.go github.com/mengelbart/moqtransport ControlMessageStream 7 | // 8 | 9 | package moqtransport 10 | 11 | import ( 12 | iter "iter" 13 | reflect "reflect" 14 | 15 | wire "github.com/mengelbart/moqtransport/internal/wire" 16 | gomock "go.uber.org/mock/gomock" 17 | ) 18 | 19 | // MockControlMessageStream is a mock of ControlMessageStream interface. 20 | type MockControlMessageStream struct { 21 | ctrl *gomock.Controller 22 | recorder *MockControlMessageStreamMockRecorder 23 | isgomock struct{} 24 | } 25 | 26 | // MockControlMessageStreamMockRecorder is the mock recorder for MockControlMessageStream. 
27 | type MockControlMessageStreamMockRecorder struct { 28 | mock *MockControlMessageStream 29 | } 30 | 31 | // NewMockControlMessageStream creates a new mock instance. 32 | func NewMockControlMessageStream(ctrl *gomock.Controller) *MockControlMessageStream { 33 | mock := &MockControlMessageStream{ctrl: ctrl} 34 | mock.recorder = &MockControlMessageStreamMockRecorder{mock} 35 | return mock 36 | } 37 | 38 | // EXPECT returns an object that allows the caller to indicate expected use. 39 | func (m *MockControlMessageStream) EXPECT() *MockControlMessageStreamMockRecorder { 40 | return m.recorder 41 | } 42 | 43 | // read mocks base method. 44 | func (m *MockControlMessageStream) read() iter.Seq2[wire.ControlMessage, error] { 45 | m.ctrl.T.Helper() 46 | ret := m.ctrl.Call(m, "read") 47 | ret0, _ := ret[0].(iter.Seq2[wire.ControlMessage, error]) 48 | return ret0 49 | } 50 | 51 | // read indicates an expected call of read. 52 | func (mr *MockControlMessageStreamMockRecorder) read() *MockControlMessageStreamreadCall { 53 | mr.mock.ctrl.T.Helper() 54 | call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "read", reflect.TypeOf((*MockControlMessageStream)(nil).read)) 55 | return &MockControlMessageStreamreadCall{Call: call} 56 | } 57 | 58 | // MockControlMessageStreamreadCall wrap *gomock.Call 59 | type MockControlMessageStreamreadCall struct { 60 | *gomock.Call 61 | } 62 | 63 | // Return rewrite *gomock.Call.Return 64 | func (c *MockControlMessageStreamreadCall) Return(arg0 iter.Seq2[wire.ControlMessage, error]) *MockControlMessageStreamreadCall { 65 | c.Call = c.Call.Return(arg0) 66 | return c 67 | } 68 | 69 | // Do rewrite *gomock.Call.Do 70 | func (c *MockControlMessageStreamreadCall) Do(f func() iter.Seq2[wire.ControlMessage, error]) *MockControlMessageStreamreadCall { 71 | c.Call = c.Call.Do(f) 72 | return c 73 | } 74 | 75 | // DoAndReturn rewrite *gomock.Call.DoAndReturn 76 | func (c *MockControlMessageStreamreadCall) DoAndReturn(f func() 
iter.Seq2[wire.ControlMessage, error]) *MockControlMessageStreamreadCall { 77 | c.Call = c.Call.DoAndReturn(f) 78 | return c 79 | } 80 | 81 | // write mocks base method. 82 | func (m *MockControlMessageStream) write(arg0 wire.ControlMessage) error { 83 | m.ctrl.T.Helper() 84 | ret := m.ctrl.Call(m, "write", arg0) 85 | ret0, _ := ret[0].(error) 86 | return ret0 87 | } 88 | 89 | // write indicates an expected call of write. 90 | func (mr *MockControlMessageStreamMockRecorder) write(arg0 any) *MockControlMessageStreamwriteCall { 91 | mr.mock.ctrl.T.Helper() 92 | call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "write", reflect.TypeOf((*MockControlMessageStream)(nil).write), arg0) 93 | return &MockControlMessageStreamwriteCall{Call: call} 94 | } 95 | 96 | // MockControlMessageStreamwriteCall wrap *gomock.Call 97 | type MockControlMessageStreamwriteCall struct { 98 | *gomock.Call 99 | } 100 | 101 | // Return rewrite *gomock.Call.Return 102 | func (c *MockControlMessageStreamwriteCall) Return(arg0 error) *MockControlMessageStreamwriteCall { 103 | c.Call = c.Call.Return(arg0) 104 | return c 105 | } 106 | 107 | // Do rewrite *gomock.Call.Do 108 | func (c *MockControlMessageStreamwriteCall) Do(f func(wire.ControlMessage) error) *MockControlMessageStreamwriteCall { 109 | c.Call = c.Call.Do(f) 110 | return c 111 | } 112 | 113 | // DoAndReturn rewrite *gomock.Call.DoAndReturn 114 | func (c *MockControlMessageStreamwriteCall) DoAndReturn(f func(wire.ControlMessage) error) *MockControlMessageStreamwriteCall { 115 | c.Call = c.Call.DoAndReturn(f) 116 | return c 117 | } 118 | -------------------------------------------------------------------------------- /handler.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | // Common Message types. Handlers can react to any of these messages. 
4 | const ( 5 | MessageSubscribe = "SUBSCRIBE" 6 | MessageFetch = "FETCH" 7 | MessageAnnounce = "ANNOUNCE" 8 | MessageAnnounceCancel = "ANNOUNCE_CANCEL" 9 | MessageUnannounce = "UNANNOUNCE" 10 | MessageTrackStatusRequest = "TRACK_STATUS_REQUEST" 11 | MessageTrackStatus = "TRACK_STATUS" 12 | MessageGoAway = "GO_AWAY" 13 | MessageSubscribeAnnounces = "SUBSCRIBE_ANNOUNCES" 14 | MessageUnsubscribeAnnounces = "UNSUBSCRIBE_ANNOUNCES" 15 | ) 16 | 17 | // Message represents a message from the peer that can be handled by the 18 | // application. 19 | type Message struct { 20 | // Method describes the type of the message. 21 | Method string 22 | 23 | // RequestID is set if the message references a request. 24 | RequestID uint64 25 | 26 | // Namespace is set if the message references a namespace. 27 | Namespace []string 28 | // Track is set if the message references a track. 29 | Track string 30 | 31 | // Authorization 32 | Authorization string 33 | 34 | // NewSessionURI is set in a GoAway message and points to a URI that can be 35 | // used to setup a new session before closing the current session. 36 | NewSessionURI string 37 | 38 | // ErrorCode is set if the message is an error message. 39 | ErrorCode uint64 40 | // ReasonPhrase is set if the message is an error message. 41 | ReasonPhrase string 42 | } 43 | 44 | // ResponseWriter can be used to respond to messages that expect a response. 45 | type ResponseWriter interface { 46 | // Accept sends an affirmative response to a message. 47 | Accept() error 48 | 49 | // Reject sends a negative response to a message. 50 | Reject(code uint64, reason string) error 51 | } 52 | 53 | // Publisher is the interface implemented by SubscribeResponseWriters 54 | type Publisher interface { 55 | // SendDatagram sends an object in a datagram. 56 | SendDatagram(Object) error 57 | 58 | // OpenSubgroup opens and returns a new subgroup. 
59 | OpenSubgroup(groupID, subgroupID uint64, priority uint8) (*Subgroup, error) 60 | 61 | // CloseWithError closes the track and sends SUBSCRIBE_DONE with code and 62 | // reason. 63 | CloseWithError(code uint64, reason string) error 64 | } 65 | 66 | // FetchPublisher is the interface implemented by ResponseWriters of Fetch 67 | // messages. 68 | type FetchPublisher interface { 69 | // FetchStream opens and returns a new fetch stream. 70 | FetchStream() (*FetchStream, error) 71 | } 72 | 73 | // StatusRequestHandler is the interface implemented by ResponseWriters of 74 | // TrackStatusRequest messages. The first call to Accept sends the response. 75 | // Calling Reject sets the status to "track does not exist" and then calls 76 | // Accept. Reject ignores the errorCode and reasonPhrase. Applications are 77 | // responsible for following the rules of track status messages. 78 | type StatusRequestHandler interface { 79 | // SetStatus sets the status for the response. Call this before calling 80 | // Accept. 81 | SetStatus(statusCode, lastGroupID, lastObjectID uint64) 82 | } 83 | 84 | // Handler is the handler interface for non-specific MoQ messages. 85 | type Handler interface { 86 | Handle(ResponseWriter, *Message) 87 | } 88 | 89 | // HandlerFunc is a type that implements Handler. 90 | type HandlerFunc func(ResponseWriter, *Message) 91 | 92 | // Handle implements Handler. 93 | func (f HandlerFunc) Handle(rw ResponseWriter, r *Message) { 94 | f(rw, r) 95 | } 96 | 97 | // SubscribeHandler is the handler interface for handling SUBSCRIBE messages. 98 | type SubscribeHandler interface { 99 | HandleSubscribe(*SubscribeResponseWriter, *SubscribeMessage) 100 | } 101 | 102 | // SubscribeHandlerFunc is a type that implements SubscribeHandler. 103 | type SubscribeHandlerFunc func(*SubscribeResponseWriter, *SubscribeMessage) 104 | 105 | // HandleSubscribe implements SubscribeHandler.
106 | func (f SubscribeHandlerFunc) HandleSubscribe(rw *SubscribeResponseWriter, m *SubscribeMessage) { 107 | f(rw, m) 108 | } 109 | 110 | // SubscribeUpdateHandler is the handler interface for handling SUBSCRIBE_UPDATE messages. 111 | type SubscribeUpdateHandler interface { 112 | HandleSubscribeUpdate(*SubscribeUpdateMessage) 113 | } 114 | 115 | // SubscribeUpdateHandlerFunc is a type that implements SubscribeUpdateHandler. 116 | type SubscribeUpdateHandlerFunc func(*SubscribeUpdateMessage) 117 | 118 | // HandleSubscribeUpdate implements SubscribeUpdateHandler. 119 | func (f SubscribeUpdateHandlerFunc) HandleSubscribeUpdate(m *SubscribeUpdateMessage) { 120 | f(m) 121 | } 122 | -------------------------------------------------------------------------------- /local_track.go: -------------------------------------------------------------------------------- 1 | package moqtransport 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "sync" 7 | 8 | "github.com/mengelbart/moqtransport/internal/slices" 9 | "github.com/mengelbart/moqtransport/internal/wire" 10 | "github.com/mengelbart/qlog" 11 | "github.com/mengelbart/qlog/moqt" 12 | ) 13 | 14 | var ( 15 | ErrUnsusbcribed = errors.New("track closed, peer unsubscribed") 16 | ErrSubscriptionDone = errors.New("track closed, subscription done") 17 | ) 18 | 19 | type subscribeDoneCallback func(code, count uint64, reason string) error 20 | 21 | type localTrack struct { 22 | qlogger *qlog.Logger 23 | 24 | conn Connection 25 | requestID uint64 26 | trackAlias uint64 27 | subgroupCount uint64 28 | fetchStreamLock sync.Mutex 29 | fetchStream *FetchStream 30 | ctx context.Context 31 | cancelCtx context.CancelCauseFunc 32 | subscribeDone subscribeDoneCallback 33 | } 34 | 35 | func newLocalTrack(conn Connection, requestID, trackAlias uint64, onSubscribeDone subscribeDoneCallback, qlogger *qlog.Logger) *localTrack { 36 | ctx, cancel := context.WithCancelCause(context.Background()) 37 | lt := &localTrack{ 38 | qlogger: qlogger, 39 | conn: 
conn, 40 | requestID: requestID, 41 | trackAlias: trackAlias, 42 | subgroupCount: 0, 43 | fetchStreamLock: sync.Mutex{}, 44 | fetchStream: nil, 45 | ctx: ctx, 46 | cancelCtx: cancel, 47 | subscribeDone: onSubscribeDone, 48 | } 49 | return lt 50 | } 51 | 52 | func (p *localTrack) getFetchStream() (*FetchStream, error) { 53 | p.fetchStreamLock.Lock() 54 | defer p.fetchStreamLock.Unlock() 55 | 56 | if err := p.closed(); err != nil { 57 | return nil, err 58 | } 59 | if p.fetchStream != nil { 60 | return p.fetchStream, nil 61 | } 62 | stream, err := p.conn.OpenUniStream() 63 | if err != nil { 64 | return nil, err 65 | } 66 | p.fetchStream, err = newFetchStream(stream, p.requestID, p.qlogger) 67 | if err != nil { 68 | return nil, err 69 | } 70 | return p.fetchStream, nil 71 | } 72 | 73 | func (p *localTrack) sendDatagram(o Object) error { 74 | if err := p.closed(); err != nil { 75 | return err 76 | } 77 | om := &wire.ObjectDatagramMessage{ 78 | TrackAlias: p.trackAlias, 79 | GroupID: o.GroupID, 80 | ObjectID: o.ObjectID, 81 | PublisherPriority: 0, 82 | ObjectExtensionHeaders: nil, 83 | ObjectStatus: 0, 84 | ObjectPayload: o.Payload, 85 | } 86 | var buf []byte 87 | buf = om.AppendDatagram(buf) 88 | if p.qlogger != nil { 89 | eth := slices.Collect(slices.Map( 90 | om.ObjectExtensionHeaders, 91 | func(e wire.KeyValuePair) moqt.ExtensionHeader { 92 | return moqt.ExtensionHeader{ 93 | HeaderType: 0, // TODO 94 | HeaderValue: 0, // TODO 95 | HeaderLength: 0, // TODO 96 | Payload: qlog.RawInfo{}, 97 | } 98 | }), 99 | ) 100 | name := moqt.ObjectDatagramEventCreated 101 | if len(om.ObjectPayload) == 0 { // no payload: log as a status-only datagram event 102 | name = moqt.ObjectDatagramStatusEventCreated 103 | } 104 | p.qlogger.Log(moqt.ObjectDatagramEvent{ 105 | EventName: name, 106 | TrackAlias: om.TrackAlias, 107 | GroupID: om.GroupID, 108 | ObjectID: om.ObjectID, 109 | PublisherPriority: om.PublisherPriority, 110 | ExtensionHeadersLength: uint64(len(om.ObjectExtensionHeaders)), 111 | ExtensionHeaders: eth, 112 | ObjectStatus:
uint64(om.ObjectStatus), 113 | Payload: qlog.RawInfo{ 114 | Length: uint64(len(om.ObjectPayload)), 115 | PayloadLength: uint64(len(om.ObjectPayload)), 116 | Data: om.ObjectPayload, 117 | }, 118 | }) 119 | } 120 | return p.conn.SendDatagram(buf) 121 | } 122 | 123 | func (p *localTrack) openSubgroup(groupID, subgroupID uint64, priority uint8) (*Subgroup, error) { 124 | if err := p.closed(); err != nil { 125 | return nil, err 126 | } 127 | stream, err := p.conn.OpenUniStream() 128 | if err != nil { 129 | return nil, err 130 | } 131 | p.subgroupCount++ 132 | return newSubgroup(stream, p.trackAlias, groupID, subgroupID, priority, p.qlogger) 133 | } 134 | 135 | func (s *localTrack) close(code uint64, reason string) error { 136 | s.cancelCtx(ErrSubscriptionDone) 137 | if s.subscribeDone != nil { 138 | return s.subscribeDone(code, s.subgroupCount, reason) 139 | } 140 | return nil 141 | } 142 | 143 | func (s *localTrack) unsubscribe() { 144 | s.cancelCtx(ErrUnsusbcribed) 145 | } 146 | 147 | func (s *localTrack) closed() error { 148 | select { 149 | case <-s.ctx.Done(): 150 | return context.Cause(s.ctx) 151 | default: 152 | return nil 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /internal/wire/fetch_message.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/mengelbart/qlog" 7 | "github.com/quic-go/quic-go/quicvarint" 8 | ) 9 | 10 | const ( 11 | FetchTypeStandalone = 0x01 12 | FetchTypeRelativeJoining = 0x02 13 | FetchTypeAbsoluteJoining = 0x03 14 | ) 15 | 16 | // TODO: Add tests 17 | type FetchMessage struct { 18 | RequestID uint64 19 | SubscriberPriority uint8 20 | GroupOrder uint8 21 | FetchType uint64 22 | TrackNamespace Tuple 23 | TrackName []byte 24 | StartGroup uint64 25 | StartObject uint64 26 | EndGroup uint64 27 | EndObject uint64 28 | JoiningSubscribeID uint64 29 | JoiningStart uint64 30 | Parameters 
KVPList 31 | } 32 | 33 | // LogValue implements slog.LogValuer. 34 | func (m *FetchMessage) LogValue() slog.Value { 35 | attrs := []slog.Attr{ 36 | slog.String("type", "fetch"), 37 | slog.Uint64("request_id", m.RequestID), 38 | slog.Any("subscriber_priority", m.SubscriberPriority), 39 | slog.Any("group_order", m.GroupOrder), 40 | slog.Uint64("fetch_type", m.FetchType), 41 | } 42 | 43 | if m.FetchType == FetchTypeStandalone { 44 | attrs = append(attrs, 45 | slog.Any("track_namespace", m.TrackNamespace), 46 | slog.Any("track_name", qlog.RawInfo{ 47 | Length: uint64(len(m.TrackName)), 48 | PayloadLength: uint64(len(m.TrackName)), 49 | Data: m.TrackName, 50 | }), 51 | slog.Uint64("start_group", m.StartGroup), 52 | slog.Uint64("start_object", m.StartObject), 53 | slog.Uint64("end_group", m.EndGroup), 54 | slog.Uint64("end_object", m.EndObject), 55 | ) 56 | } 57 | if m.FetchType == FetchTypeAbsoluteJoining || m.FetchType == FetchTypeRelativeJoining { 58 | attrs = append(attrs, 59 | slog.Uint64("joining_subscribe_id", m.JoiningSubscribeID), 60 | slog.Uint64("preceding_group_offset", m.JoiningStart), 61 | ) 62 | } 63 | 64 | attrs = append(attrs, 65 | slog.Uint64("number_of_parameters", uint64(len(m.Parameters))), 66 | ) 67 | 68 | if len(m.Parameters) > 0 { 69 | attrs = append(attrs, 70 | slog.Any("setup_parameters", m.Parameters), 71 | ) 72 | } 73 | return slog.GroupValue(attrs...)
74 | } 75 | 76 | func (m FetchMessage) Type() controlMessageType { 77 | return messageTypeFetch 78 | } 79 | 80 | func (m *FetchMessage) Append(buf []byte) []byte { 81 | buf = quicvarint.Append(buf, m.RequestID) 82 | buf = append(buf, m.SubscriberPriority) 83 | buf = append(buf, m.GroupOrder) 84 | buf = quicvarint.Append(buf, m.FetchType) 85 | 86 | if m.FetchType == FetchTypeStandalone { 87 | buf = m.TrackNamespace.append(buf) 88 | buf = appendVarIntBytes(buf, m.TrackName) 89 | buf = quicvarint.Append(buf, m.StartGroup) 90 | buf = quicvarint.Append(buf, m.StartObject) 91 | buf = quicvarint.Append(buf, m.EndGroup) 92 | buf = quicvarint.Append(buf, m.EndObject) 93 | } else { 94 | buf = quicvarint.Append(buf, m.JoiningSubscribeID) 95 | buf = quicvarint.Append(buf, m.JoiningStart) 96 | } 97 | 98 | return m.Parameters.appendNum(buf) 99 | } 100 | 101 | func (m *FetchMessage) parse(_ Version, data []byte) (err error) { 102 | var n int 103 | m.RequestID, n, err = quicvarint.Parse(data) 104 | if err != nil { 105 | return err 106 | } 107 | data = data[n:] 108 | 109 | if len(data) < 2 { 110 | return errLengthMismatch 111 | } 112 | m.SubscriberPriority = data[0] 113 | m.GroupOrder = data[1] 114 | if m.GroupOrder > 2 { 115 | return errInvalidGroupOrder 116 | } 117 | data = data[2:] 118 | 119 | m.FetchType, n, err = quicvarint.Parse(data) 120 | if err != nil { 121 | return err 122 | } 123 | data = data[n:] 124 | 125 | if m.FetchType < FetchTypeStandalone || m.FetchType > FetchTypeAbsoluteJoining { 126 | return errInvalidFetchType 127 | } 128 | 129 | if m.FetchType == FetchTypeStandalone { 130 | m.TrackNamespace, n, err = parseTuple(data) 131 | if err != nil { 132 | return err 133 | } 134 | data = data[n:] 135 | 136 | m.TrackName, n, err = parseVarIntBytes(data) 137 | if err != nil { 138 | return err 139 | } 140 | data = data[n:] 141 | 142 | m.StartGroup, n, err = quicvarint.Parse(data) 143 | if err != nil { 144 | return err 145 | } 146 | data = data[n:] 147 | 148 | m.StartObject, 
n, err = quicvarint.Parse(data) 149 | if err != nil { 150 | return err 151 | } 152 | data = data[n:] 153 | 154 | m.EndGroup, n, err = quicvarint.Parse(data) 155 | if err != nil { 156 | return err 157 | } 158 | data = data[n:] 159 | 160 | m.EndObject, n, err = quicvarint.Parse(data) 161 | if err != nil { 162 | return err 163 | } 164 | data = data[n:] 165 | } else { 166 | m.JoiningSubscribeID, n, err = quicvarint.Parse(data) 167 | if err != nil { 168 | return err 169 | } 170 | data = data[n:] 171 | 172 | m.JoiningStart, n, err = quicvarint.Parse(data) 173 | if err != nil { 174 | return err 175 | } 176 | data = data[n:] 177 | } 178 | 179 | m.Parameters = KVPList{} 180 | return m.Parameters.parseNum(data) 181 | } 182 | --------------------------------------------------------------------------------