├── .gitignore ├── go.mod ├── coordinator ├── tinmemory_test.go ├── coordinator.go ├── test_util.go ├── rmsg_serialise.go ├── tbad_netwk_test.go ├── trmsg_serialise_test.go ├── rmsg.go ├── inmemory.go └── readwrite.go ├── go.sum ├── bin ├── perf │ ├── tmain_test.go │ └── main.go ├── svr │ ├── client │ │ ├── data.go.disabled │ │ ├── ttrader_test.go.disabled │ │ ├── trader.go.disabled │ │ ├── server.go.disabled │ │ ├── balanceManager.go.disabled │ │ └── ttrader_unit_test.go.disabled │ ├── main.go.disabled │ └── html │ │ └── index.html └── itchdebug │ └── main.go ├── matcher ├── pqueue │ ├── slab.go │ ├── public_prioq.go │ ├── public_refprioq.go │ ├── refprioq.go │ ├── order.go │ ├── rbtree.go │ └── tprioq_test.go ├── tmatch_test.go ├── tcompare_test.go ├── tcoordinator_test.go ├── trefmatcher_test.go ├── matcher.go └── testsuite.go ├── q ├── meddlers.go └── meddle_q.go ├── msg ├── serialise.go ├── tserialise_test.go ├── msg.go ├── maker.go └── tmsg_test.go ├── itch └── reader.go └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | bin/perf/perf 2 | bin/itchdebug/itchdebug 3 | bin/svr/svr 4 | *.swp 5 | *.prof 6 | *.odat 7 | _test 8 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fmstephe/matching_engine 2 | 3 | go 1.18 4 | 5 | require github.com/fmstephe/flib v0.0.0-20170802081819-76e5765dde32 // indirect 6 | -------------------------------------------------------------------------------- /coordinator/tinmemory_test.go: -------------------------------------------------------------------------------- 1 | package coordinator 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestGoodNetwork(t *testing.T) { 8 | testBadNetwork(t, 0.0, InMemory) 9 | } 10 | -------------------------------------------------------------------------------- /go.sum: 
-------------------------------------------------------------------------------- 1 | github.com/fmstephe/flib v0.0.0-20170802081819-76e5765dde32 h1:oGgwjuc9Ftzc0Cnf3VtoCqIwW3874MqplZADiDKXrKw= 2 | github.com/fmstephe/flib v0.0.0-20170802081819-76e5765dde32/go.mod h1:Shzzxm47fvpk50Rbd41T2pSEme5oaocOw1kuYTA0xaM= 3 | -------------------------------------------------------------------------------- /bin/perf/tmain_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | // This test doesn't test anything other than that the perf tool runs and does not deadlock 8 | func TestPerf(t *testing.T) { 9 | doPerf(false) 10 | } 11 | -------------------------------------------------------------------------------- /coordinator/coordinator.go: -------------------------------------------------------------------------------- 1 | package coordinator 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | type CoordinatorFunc func(reader io.ReadCloser, writer io.WriteCloser, app AppMsgRunner, originId uint32, name string, log bool) 8 | 9 | type AppMsgRunner interface { 10 | Config(name string, in MsgReader, out MsgWriter) 11 | Run() 12 | } 13 | 14 | type AppMsgHelper struct { 15 | Name string 16 | In MsgReader 17 | Out MsgWriter 18 | } 19 | 20 | func (a *AppMsgHelper) Config(name string, in MsgReader, out MsgWriter) { 21 | a.Name = name 22 | a.In = in 23 | a.Out = out 24 | } 25 | -------------------------------------------------------------------------------- /coordinator/test_util.go: -------------------------------------------------------------------------------- 1 | package coordinator 2 | 3 | import ( 4 | "math/rand" 5 | ) 6 | 7 | func randomUniqueMsgs() []*RMessage { 8 | uniqueMap := make(map[uint32]bool) 9 | r := rand.New(rand.NewSource(1)) 10 | msgs := make([]*RMessage, 0) 11 | for i := 0; i < 100; i++ { 12 | origin := uint32(r.Int31()) 13 | id := uint32(r.Int31()) 14 | setOnce(uniqueMap, origin, id)
15 | m := &RMessage{originId: origin, msgId: id} 16 | msgs = append(msgs, m) 17 | } 18 | return msgs 19 | } 20 | 21 | func setOnce(uniqueMap map[uint32]bool, origin, id uint32) { 22 | val := origin + id 23 | if uniqueMap[val] { 24 | panic("Generated non-unique message") 25 | } 26 | uniqueMap[val] = true 27 | } 28 | -------------------------------------------------------------------------------- /bin/svr/client/data.go.disabled: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/msg" 5 | ) 6 | 7 | type receivedMessage struct { 8 | Accepted bool `json:"accepted"` 9 | Message msg.Message `json:"message"` 10 | } 11 | 12 | type Response struct { 13 | State traderState `json:"state"` 14 | Received receivedMessage `json:"received"` 15 | Comment string `json:"comment"` 16 | } 17 | 18 | type traderState struct { 19 | CurrentBalance uint64 `json:"currentBalance"` 20 | AvailableBalance uint64 `json:"availableBalance"` 21 | StocksHeld map[string]uint64 `json:"stocksHeld"` 22 | StocksToSell map[string]uint64 `json:"stocksToSell"` 23 | Outstanding []msg.Message `json:"outstanding"` 24 | } 25 | -------------------------------------------------------------------------------- /matcher/pqueue/slab.go: -------------------------------------------------------------------------------- 1 | package pqueue 2 | 3 | import () 4 | 5 | type Slab struct { 6 | free *OrderNode 7 | orders []OrderNode 8 | } 9 | 10 | func NewSlab(size int) *Slab { 11 | s := &Slab{orders: make([]OrderNode, size)} 12 | // Thread the free list through orders back to front so it starts at 13 | // orders[0] and follows index order. A zero-sized slab keeps a nil free 14 | // list (Malloc then falls back to the heap) instead of panicking. 15 | for i := len(s.orders) - 1; i >= 0; i-- { 16 | s.orders[i].nextFree = s.free 17 | s.free = &s.orders[i] 18 | } 19 | return s 20 | } 21 | 22 | func (s *Slab) Malloc() *OrderNode { 23 | o := s.free 24 | if o == nil { 25 | o = &OrderNode{} 26 | } 27 | s.free = o.nextFree 28 | o.nextFree = o // Slab allocated order marker 29 | return o 30 | } 31
| 32 | func (s *Slab) Free(o *OrderNode) { 33 | if o.nextFree == o { 34 | o.nextFree = s.free 35 | s.free = o 36 | } 37 | // OrderNodes that were not slab allocated are left to the garbage collector 38 | } 39 | -------------------------------------------------------------------------------- /q/meddlers.go: -------------------------------------------------------------------------------- 1 | package q 2 | 3 | import ( 4 | "container/list" 5 | "fmt" 6 | "math" 7 | "math/rand" 8 | ) 9 | 10 | type freqDropMeddler struct { 11 | trigger int64 12 | msgCount int64 13 | } 14 | 15 | func NewFreqDropMeddler(trigger int64) *freqDropMeddler { 16 | if trigger < 1 { 17 | trigger = math.MaxInt64 18 | } 19 | return &freqDropMeddler{trigger: trigger, msgCount: 0} 20 | } 21 | 22 | func (m *freqDropMeddler) Meddle(buf *list.List) { 23 | m.msgCount++ 24 | if buf.Len() > 0 && m.msgCount > m.trigger { 25 | buf.Remove(buf.Front()) 26 | m.msgCount = 0 27 | } 28 | } 29 | 30 | type probDropMeddler struct { 31 | prob float64 32 | } 33 | 34 | func NewProbDropMeddler(prob float64) *probDropMeddler { 35 | if prob < 0 || prob > 1 { 36 | panic(fmt.Sprintf("Probability (%f) must be 0.0 <= x <= 1.0", prob)) 37 | } 38 | return &probDropMeddler{prob: prob} 39 | } 40 | 41 | func (m *probDropMeddler) Meddle(buf *list.List) { 42 | if buf.Len() > 0 && m.prob > rand.Float64() { 43 | buf.Remove(buf.Front()) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /matcher/pqueue/public_prioq.go: -------------------------------------------------------------------------------- 1 | package pqueue 2 | 3 | import () 4 | 5 | type MatchQueues struct { 6 | buyTree rbtree 7 | sellTree rbtree 8 | orders rbtree 9 | size int 10 | } 11 | 12 | func (m *MatchQueues) Size() int { 13 | return m.size 14 | } 15 | 16 | func (m *MatchQueues) PushBuy(b *OrderNode) { 17 | m.size++ 18 | m.buyTree.push(&b.priceNode) 19 | m.orders.push(&b.guidNode) 20 | } 21 | 22 | func (m *MatchQueues)
PushSell(s *OrderNode) { 23 | m.size++ 24 | m.sellTree.push(&s.priceNode) 25 | m.orders.push(&s.guidNode) 26 | } 27 | 28 | func (m *MatchQueues) PeekBuy() *OrderNode { 29 | return m.buyTree.peekMax().getOrderNode() 30 | } 31 | 32 | func (m *MatchQueues) PeekSell() *OrderNode { 33 | return m.sellTree.peekMin().getOrderNode() 34 | } 35 | 36 | func (m *MatchQueues) PopBuy() *OrderNode { 37 | // getOrderNode can yield nil (Cancel relies on this too); only count an order 38 | // as removed when one was returned, so popping an empty queue keeps size >= 0. 39 | o := m.buyTree.popMax().getOrderNode() 40 | if o != nil { 41 | m.size-- 42 | } 43 | return o 44 | } 45 | 46 | func (m *MatchQueues) PopSell() *OrderNode { 47 | // As in PopBuy, size only tracks orders that were actually removed. 48 | o := m.sellTree.popMin().getOrderNode() 49 | if o != nil { 50 | m.size-- 51 | } 52 | return o 53 | } 54 | 55 | func (m *MatchQueues) Cancel(o *OrderNode) *OrderNode { 56 | po := m.orders.cancel(o.Guid()).getOrderNode() 57 | if po != nil { 58 | m.size-- 59 | } 60 | return po 61 | } 62 | -------------------------------------------------------------------------------- /matcher/pqueue/public_refprioq.go: -------------------------------------------------------------------------------- 1 | package pqueue 2 | 3 | import () 4 | 5 | type RefMatchQueues struct { 6 | buys *pqueue 7 | sells *pqueue 8 | size int 9 | } 10 | 11 | func NewRefMatchQueues(lowPrice, highPrice uint64) *RefMatchQueues { 12 | buys := mkPrioq(lowPrice, highPrice) 13 | sells := mkPrioq(lowPrice, highPrice) 14 | return &RefMatchQueues{buys: buys, sells: sells} 15 | } 16 | 17 | func (m *RefMatchQueues) Size() int { 18 | return m.size 19 | } 20 | 21 | func (m *RefMatchQueues) PushBuy(b *OrderNode) { 22 | m.size++ 23 | m.buys.push(b) 24 | } 25 | 26 | func (m *RefMatchQueues) PushSell(s *OrderNode) { 27 | m.size++ 28 | m.sells.push(s) 29 | } 30 | 31 | func (m *RefMatchQueues) PeekBuy() *OrderNode { 32 | return m.buys.peekMax() 33 | } 34 | 35 | func (m *RefMatchQueues) PeekSell() *OrderNode { 36 | return m.sells.peekMin() 37 | } 38 | 39 | func (m *RefMatchQueues) PopBuy() *OrderNode { 40 | // popMax returns nil when every price level is empty; only count an 41 | // order as removed when one was actually returned. 42 | o := m.buys.popMax() 43 | if o != nil { 44 | m.size-- 45 | } 46 | return o 47 | } 48 | 49 | func (m *RefMatchQueues) PopSell() *OrderNode { 50 | // popMin likewise returns nil on an empty queue. 51 | o := m.sells.popMin() 52 | if o != nil { 53 | m.size-- 54 | } 55 | return o 56 | } 57
| 58 | func (m *RefMatchQueues) Cancel(o *OrderNode) *OrderNode { 59 | c := m.buys.cancel(o.Guid()) 60 | if c != nil { 61 | m.size-- 62 | return c 63 | } 64 | c = m.sells.cancel(o.Guid()) 65 | if c != nil { 66 | m.size-- 67 | return c 68 | } 69 | return nil 70 | } 71 | -------------------------------------------------------------------------------- /msg/serialise.go: -------------------------------------------------------------------------------- 1 | package msg 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | const ( 10 | kindOffset = 0 // 8 bytes 11 | priceOffset = 8 // 8 bytes 12 | amountOffset = 16 // 8 bytes 13 | stockIdOffset = 24 // 8 bytes 14 | traderIdOffset = 32 // 4 bytes 15 | tradeIdOffset = 36 // 4 bytes 16 | ByteSize = 40 17 | ) 18 | 19 | var binCoder = binary.LittleEndian 20 | 21 | // Marshal writes m into b, which must be exactly ByteSize bytes long. 22 | func (m *Message) Marshal(b []byte) error { 23 | if len(b) != ByteSize { 24 | return errors.New(fmt.Sprintf("Wrong sized byte buffer. Expecting %d, found %d", ByteSize, len(b))) 25 | } 26 | binCoder.PutUint64(b[kindOffset:priceOffset], uint64(m.Kind)) 27 | binCoder.PutUint64(b[priceOffset:amountOffset], uint64(m.Price)) 28 | binCoder.PutUint64(b[amountOffset:stockIdOffset], uint64(m.Amount)) 29 | binCoder.PutUint64(b[stockIdOffset:traderIdOffset], uint64(m.StockId)) 30 | binCoder.PutUint32(b[traderIdOffset:tradeIdOffset], uint32(m.TraderId)) 31 | binCoder.PutUint32(b[tradeIdOffset:], uint32(m.TradeId)) 32 | return nil 33 | } 34 | 35 | // Unmarshal populates m from b, which must be exactly ByteSize bytes long. 36 | func (m *Message) Unmarshal(b []byte) error { 37 | if len(b) != ByteSize { 38 | return errors.New(fmt.Sprintf("Wrong sized byte buffer.
Expecting %d, found %d", ByteSize, len(b))) 39 | } 40 | m.Kind = MsgKind(binCoder.Uint64(b[kindOffset:priceOffset])) 41 | m.Price = binCoder.Uint64(b[priceOffset:amountOffset]) 42 | m.Amount = binCoder.Uint64(b[amountOffset:stockIdOffset]) 43 | m.StockId = binCoder.Uint64(b[stockIdOffset:traderIdOffset]) 44 | m.TraderId = binCoder.Uint32(b[traderIdOffset:tradeIdOffset]) 45 | m.TradeId = binCoder.Uint32(b[tradeIdOffset:]) 46 | return nil 47 | } 48 | -------------------------------------------------------------------------------- /coordinator/rmsg_serialise.go: -------------------------------------------------------------------------------- 1 | package coordinator 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "fmt" 7 | "github.com/fmstephe/matching_engine/msg" 8 | ) 9 | 10 | const ( 11 | msgOffset = 0 // msg.ByteSize bytes (40) 12 | statusOffset = msg.ByteSize + 0 // 1 byte 13 | directionOffset = msg.ByteSize + 1 // 1 byte 14 | routeOffset = msg.ByteSize + 2 // 1 byte 15 | originIdOffset = msg.ByteSize + 3 // 4 bytes 16 | msgIdOffset = msg.ByteSize + 7 // 4 bytes 17 | rmsgByteSize = msg.ByteSize + 11 // (51) 18 | ) 19 | 20 | var binCoder = binary.LittleEndian 21 | 22 | // Marshal writes rm into b, which must be exactly rmsgByteSize bytes long. 23 | func (rm *RMessage) Marshal(b []byte) error { 24 | if len(b) != rmsgByteSize { 25 | return errors.New(fmt.Sprintf("Wrong sized byte buffer. Expecting %d, found %d", rmsgByteSize, len(b))) 26 | } 27 | (&rm.message).Marshal(b[:msg.ByteSize]) 28 | b[statusOffset] = byte(rm.status) 29 | b[directionOffset] = byte(rm.direction) 30 | b[routeOffset] = byte(rm.route) 31 | binCoder.PutUint32(b[originIdOffset:msgIdOffset], rm.originId) 32 | binCoder.PutUint32(b[msgIdOffset:], rm.msgId) 33 | return nil 34 | } 35 | 36 | // Unmarshal populates rm from b, which must be exactly rmsgByteSize bytes long. 37 | func (rm *RMessage) Unmarshal(b []byte) error { 38 | if len(b) != rmsgByteSize { 39 | return errors.New(fmt.Sprintf("Wrong sized byte buffer.
Expecting %d, found %d", rmsgByteSize, len(b))) 40 | } 41 | (&rm.message).Unmarshal(b[:msg.ByteSize]) 42 | rm.status = MsgStatus(b[statusOffset]) 43 | rm.direction = MsgDirection(b[directionOffset]) 44 | rm.route = MsgRoute(b[routeOffset]) 45 | rm.originId = binCoder.Uint32(b[originIdOffset:msgIdOffset]) 46 | rm.msgId = binCoder.Uint32(b[msgIdOffset:]) 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /bin/svr/client/ttrader_test.go.disabled: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/msg" 5 | "testing" 6 | ) 7 | 8 | func runTrader(traderId uint32, t *testing.T) (intoSvr, outOfSvr, orders chan *msg.Message, responses chan *Response, tc traderComm) { 9 | intoSvr = make(chan *msg.Message) 10 | outOfSvr = make(chan *msg.Message) 11 | tdr, tc := newTrader(traderId, intoSvr, outOfSvr) 12 | go tdr.run() 13 | orders = make(chan *msg.Message) 14 | responses = make(chan *Response) 15 | con := connect{traderId: traderId, orders: orders, responses: responses} 16 | tc.connecter <- con 17 | if resp := <-responses; resp.Comment != connectedComment { 18 | t.Errorf("Expecting '" + connectedComment + "' found '" + resp.Comment + "'") 19 | } 20 | return intoSvr, outOfSvr, orders, responses, tc 21 | } 22 | 23 | func TestTraderDisconnect(t *testing.T) { 24 | traderId := uint32(1) 25 | _, _, orders, responses, _ := runTrader(traderId, t) 26 | close(orders) 27 | if resp := <-responses; resp.Comment != ordersClosedComment { 28 | t.Errorf("Expecting '" + ordersClosedComment + "' found '" + resp.Comment + "'") 29 | } 30 | if <-responses != nil { 31 | t.Errorf("Expecting nil response indicating responses had closed") 32 | } 33 | } 34 | 35 | func TestTraderNewConnection(t *testing.T) { 36 | traderId := uint32(1) 37 | _, _, _, responses, tc := runTrader(traderId, t) 38 | newOrders := make(chan *msg.Message) 39 |
newResponses := make(chan *Response) 40 | con := connect{traderId: traderId, orders: newOrders, responses: newResponses} 41 | tc.connecter <- con 42 | // Old connection disconnects 43 | if resp := <-responses; resp.Comment != replacedComment { 44 | t.Errorf("Expecting '" + replacedComment + "' found '" + resp.Comment + "'") 45 | } 46 | if <-responses != nil { 47 | t.Errorf("Expecting nil response indicating responses had closed") 48 | } 49 | // New connection connected 50 | if <-newResponses == nil { 51 | t.Errorf("Expecting initial state response, received nil") 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /coordinator/tbad_netwk_test.go: -------------------------------------------------------------------------------- 1 | package coordinator 2 | 3 | import ( 4 | . "github.com/fmstephe/matching_engine/msg" 5 | "github.com/fmstephe/matching_engine/q" 6 | "testing" 7 | ) 8 | 9 | const TO_SEND = 1000 10 | 11 | const ( 12 | clientOriginId = iota 13 | serverOriginId = iota 14 | ) 15 | 16 | type echoClient struct { 17 | AppMsgHelper 18 | received []*Message 19 | complete chan bool 20 | } 21 | 22 | func newEchoClient(complete chan bool) *echoClient { 23 | return &echoClient{received: make([]*Message, TO_SEND), complete: complete} 24 | } 25 | 26 | func (c *echoClient) Run() { 27 | // sendAll runs in its own goroutine so this loop can read replies concurrently. 28 | go sendAll(c.Out) 29 | for { 30 | // m is scoped to a single iteration: received stores &m, so each entry must 31 | // point at its own Message rather than at one buffer that is overwritten. 32 | m := c.In.Read() 33 | if m.Kind == SHUTDOWN { 34 | return 35 | } 36 | if c.received[m.TradeId-1] != nil { 37 | panic("Duplicate message received") 38 | } 39 | c.received[m.TradeId-1] = &m 40 | if full(c.received) { 41 | c.complete <- true 42 | return 43 | } 44 | } 45 | } 46 | 47 | func full(received []*Message) bool { 48 | for _, rm := range received { 49 | if rm == nil { 50 | return false 51 | } 52 | } 53 | return true 54 | } 55 | 56 | func sendAll(out MsgWriter) { 57 | for i := uint32(1); i <= TO_SEND; i++ { 58 | out.Write(Message{Kind: SELL, TraderId: 1,
TradeId: i, StockId: 1, Price: 7, Amount: 1}) 59 | } 60 | } 61 | 62 | type echoServer struct { 63 | AppMsgHelper 64 | } 65 | 66 | func (s *echoServer) Run() { 67 | m := &Message{} 68 | for { 69 | *m = s.In.Read() 70 | if m.Kind == SHUTDOWN { 71 | return 72 | } 73 | // Echo the message back as a BUY. 74 | r := *m 75 | r.Kind = BUY 76 | s.Out.Write(r) 77 | } 78 | } 79 | 80 | func testBadNetwork(t *testing.T, dropProb float64, cFunc CoordinatorFunc) { 81 | complete := make(chan bool) 82 | c := newEchoClient(complete) 83 | s := &echoServer{} 84 | clientToServer := q.NewMeddleQ("clientToServer", q.NewProbDropMeddler(dropProb)) 85 | serverToClient := q.NewMeddleQ("serverToClient", q.NewProbDropMeddler(dropProb)) 86 | cFunc(serverToClient, clientToServer, c, clientOriginId, "Client", false) 87 | cFunc(clientToServer, serverToClient, s, serverOriginId, "Server", false) 88 | <-complete 89 | } 90 | -------------------------------------------------------------------------------- /matcher/tmatch_test.go: -------------------------------------------------------------------------------- 1 | package matcher 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/coordinator" 5 | "github.com/fmstephe/matching_engine/msg" 6 | "runtime" 7 | "testing" 8 | ) 9 | 10 | const ( 11 | stockId = 1 12 | trader1 = 1 13 | trader2 = 2 14 | trader3 = 3 15 | ) 16 | 17 | var matchMaker = msg.NewMessageMaker(100) 18 | 19 | type responseVals struct { 20 | price uint64 21 | amount uint32 22 | tradeId uint32 23 | stockId uint32 24 | } 25 | 26 | func TestPrice(t *testing.T) { 27 | testPrice(t, 1, 1, 1) 28 | testPrice(t, 2, 1, 1) 29 | testPrice(t, 3, 1, 2) 30 | testPrice(t, 4, 1, 2) 31 | testPrice(t, 5, 1, 3) 32 | testPrice(t, 6, 1, 3) 33 | testPrice(t, 20, 10, 15) 34 | testPrice(t, 21, 10, 15) 35 | testPrice(t, 22, 10, 16) 36 | testPrice(t, 23, 10, 16) 37 | testPrice(t, 24, 10, 17) 38 | testPrice(t, 25, 10, 17) 39 | testPrice(t, 26, 10, 18) 40 | testPrice(t, 27, 10, 18) 41 | testPrice(t, 28, 10, 19) 42 | testPrice(t, 29,
10, 19) 43 | testPrice(t, 30, 10, 20) 44 | } 45 | 46 | func testPrice(t *testing.T, bPrice, sPrice, expected uint64) { 47 | result := price(bPrice, sPrice) 48 | if result != expected { 49 | t.Errorf("price(%d,%d) does not equal %d, got %d instead.", bPrice, sPrice, expected, result) 50 | } 51 | } 52 | 53 | type testerMaker struct { 54 | } 55 | 56 | func (tm *testerMaker) Make() MatchTester { 57 | in := coordinator.NewChanReaderWriter(30) 58 | out := coordinator.NewChanReaderWriter(30) 59 | m := NewMatcher(100) 60 | m.Config("Matcher", in, out) 61 | go m.Run() 62 | return &localTester{in: in, out: out} 63 | } 64 | 65 | type localTester struct { 66 | in coordinator.MsgWriter 67 | out coordinator.MsgReader 68 | } 69 | 70 | func (lt *localTester) Send(t *testing.T, m *msg.Message) { 71 | lt.in.Write(*m) 72 | } 73 | 74 | func (lt *localTester) Expect(t *testing.T, ref *msg.Message) { 75 | m := &msg.Message{} 76 | *m = lt.out.Read() 77 | if *ref != *m { 78 | _, fname, lnum, _ := runtime.Caller(1) 79 | t.Errorf("\nExpecting: %v\nFound: %v\n%s:%d", ref, m, fname, lnum) 80 | } 81 | } 82 | 83 | func (lt *localTester) Cleanup(t *testing.T) {} 84 | 85 | func TestRunTestSuite(t *testing.T) { 86 | RunTestSuite(t, &testerMaker{}) 87 | } 88 | -------------------------------------------------------------------------------- /msg/tserialise_test.go: -------------------------------------------------------------------------------- 1 | package msg 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func messageBuffer() []byte { 8 | return make([]byte, ByteSize) 9 | } 10 | 11 | func TestMarshallDoesNotDestroyMesssage(t *testing.T) { 12 | ref := &Message{Kind: 1, Price: 2, Amount: 3, StockId: 4, TraderId: 5, TradeId: 6} 13 | m1 := &Message{} 14 | *m1 = *ref 15 | b := messageBuffer() 16 | if err := m1.Marshal(b); err != nil { 17 | t.Errorf("Unexpected marshalling error %s", err.Error()) 18 | } 19 | assertEquivalent(t, ref, m1, b) 20 | } 21 | 22 | func
TestMarshallUnMarshalPairsProducesSameMessage(t *testing.T) { 23 | m1 := &Message{Kind: 1, Price: 2, Amount: 3, StockId: 4, TraderId: 5, TradeId: 6} 24 | b := messageBuffer() 25 | if err := m1.Marshal(b); err != nil { 26 | t.Errorf("Unexpected marshalling error %s", err.Error()) 27 | } 28 | m2 := &Message{} 29 | if err := m2.Unmarshal(b); err != nil { 30 | t.Errorf("Unexpected unmarshalling error %s", err.Error()) 31 | } 32 | assertEquivalent(t, m1, m2, b) 33 | } 34 | 35 | func assertEquivalent(t *testing.T, exp, fnd *Message, b []byte) { 36 | if *fnd != *exp { 37 | t.Errorf("\nExpected to find %v\nfound %v\nMarshalled from %v", exp, fnd, b) 38 | } 39 | } 40 | 41 | func TestMarshalWithSmallBufferErrors(t *testing.T) { 42 | m1 := &Message{Kind: 1, Price: 2, Amount: 3, StockId: 4, TraderId: 5, TradeId: 6} 43 | b := make([]byte, ByteSize-1) 44 | if err := m1.Marshal(b); err == nil { 45 | t.Error("Expected marshalling error. Found none") 46 | } 47 | } 48 | 49 | func TestMarshalWithLargeBufferErrors(t *testing.T) { 50 | m1 := &Message{Kind: 1, Price: 2, Amount: 3, StockId: 4, TraderId: 5, TradeId: 6} 51 | b := make([]byte, ByteSize+1) 52 | if err := m1.Marshal(b); err == nil { 53 | t.Error("Expected marshalling error. Found none") 54 | } 55 | } 56 | 57 | func TestUnmarshalWithSmallBufferErrors(t *testing.T) { 58 | m1 := &Message{} 59 | b := make([]byte, ByteSize-1) 60 | if err := m1.Unmarshal(b); err == nil { 61 | t.Error("Expected unmarshalling error. Found none") 62 | } 63 | } 64 | 65 | func TestUnmarshalWithLargeBufferErrors(t *testing.T) { 66 | m1 := &Message{} 67 | b := make([]byte, ByteSize+1) 68 | if err := m1.Unmarshal(b); err == nil { 69 | t.Error("Expected unmarshalling error.
Found none") 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /q/meddle_q.go: -------------------------------------------------------------------------------- 1 | package q 2 | 3 | import ( 4 | "container/list" 5 | "runtime" 6 | ) 7 | 8 | type Meddler interface { 9 | Meddle(*list.List) 10 | } 11 | 12 | type notMeddler struct{} 13 | 14 | func (m *notMeddler) Meddle(l *list.List) {} 15 | 16 | type meddleQ struct { 17 | name string 18 | writeChan chan []byte 19 | readChan chan []byte 20 | shutdown chan bool 21 | buf *list.List 22 | meddler Meddler 23 | } 24 | 25 | func NewMeddleQ(name string, meddler Meddler) *meddleQ { 26 | q := &meddleQ{ 27 | name: name, 28 | writeChan: make(chan []byte, 100), 29 | readChan: make(chan []byte, 100), 30 | shutdown: make(chan bool), 31 | buf: list.New(), 32 | meddler: meddler} 33 | go q.run() 34 | return q 35 | } 36 | 37 | func NewSimpleQ(name string) *meddleQ { 38 | q := &meddleQ{name: name, 39 | writeChan: make(chan []byte, 100), 40 | readChan: make(chan []byte, 100), 41 | shutdown: make(chan bool), 42 | buf: list.New(), 43 | meddler: &notMeddler{}} 44 | go q.run() 45 | return q 46 | } 47 | 48 | func (q *meddleQ) Read(p []byte) (int, error) { 49 | c := <-q.readChan 50 | copy(p, c) 51 | if len(p) < len(c) { 52 | return len(p), nil 53 | } 54 | return len(c), nil 55 | } 56 | 57 | func (q *meddleQ) Close() error { 58 | q.shutdown <- true 59 | return nil 60 | } 61 | 62 | func (q *meddleQ) Write(p []byte) (int, error) { 63 | c := make([]byte, len(p)) 64 | copy(c, p) 65 | q.writeChan <- c 66 | return len(c), nil 67 | } 68 | 69 | func (q *meddleQ) run() { 70 | for { 71 | if !q.read() { 72 | return 73 | } 74 | q.meddler.Meddle(q.buf) 75 | q.write() 76 | select { 77 | case <-q.shutdown: 78 | return 79 | default: 80 | } 81 | } 82 | } 83 | 84 | // read moves at most one message from writeChan into buf. With an empty buf it 85 | // blocks until a message arrives, but also listens on shutdown so that Close 86 | // cannot deadlock against an idle queue. It returns false once shut down. 87 | func (q *meddleQ) read() bool { 88 | if q.buf.Len() == 0 { 89 | select { 90 | case r := <-q.writeChan: 91 | q.buf.PushBack(r)
92 | case <-q.shutdown: 93 | return false 94 | } 95 | return true 96 | } 97 | select { 98 | case r := <-q.writeChan: 99 | q.buf.PushBack(r) 100 | default: 101 | runtime.Gosched() 102 | } 103 | return true 104 | } 105 | 106 | func (q *meddleQ) write() { 107 | if q.buf.Len() > 0 { 108 | head := q.buf.Front() 109 | val := head.Value.([]byte) 110 | select { 111 | case q.readChan <- val: 112 | q.buf.Remove(head) 113 | default: 114 | runtime.Gosched() 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /matcher/pqueue/refprioq.go: -------------------------------------------------------------------------------- 1 | package pqueue 2 | 3 | import () 4 | 5 | // An easy to build priority queue 6 | type pqueue struct { 7 | prios [][]*OrderNode 8 | lowPrice, highPrice uint64 9 | } 10 | 11 | func mkPrioq(lowPrice, highPrice uint64) *pqueue { 12 | prios := make([][]*OrderNode, highPrice-lowPrice+1) 13 | return &pqueue{prios: prios, lowPrice: lowPrice, highPrice: highPrice} 14 | } 15 | 16 | func (q *pqueue) push(o *OrderNode) { 17 | idx := o.Price() - q.lowPrice 18 | prio := q.prios[idx] 19 | prio = append(prio, o) 20 | q.prios[idx] = prio 21 | } 22 | 23 | func (q *pqueue) peekMax() *OrderNode { 24 | if len(q.prios) == 0 { 25 | return nil 26 | } 27 | for i := len(q.prios) - 1; i >= 0; i-- { 28 | switch { 29 | case len(q.prios[i]) > 0: 30 | return q.prios[i][0] 31 | default: 32 | continue 33 | } 34 | } 35 | return nil 36 | } 37 | 38 | func (q *pqueue) popMax() *OrderNode { 39 | if len(q.prios) == 0 { 40 | return nil 41 | } 42 | for i := len(q.prios) - 1; i >= 0; i-- { 43 | switch { 44 | case len(q.prios[i]) > 0: 45 | return q.pop(i) 46 | default: 47 | continue 48 | } 49 | } 50 | return nil 51 | } 52 | 53 | func (q *pqueue) peekMin() *OrderNode { 54 | if len(q.prios) == 0 { 55 | return nil 56 | } 57 | for i := 0; i < len(q.prios); i++ { 58 | switch { 59 | case len(q.prios[i]) > 0: 60 | return q.prios[i][0] 61 | default: 62 | continue 63 | } 64 | } 65 | return nil 66 | } 67 | 68 | func (q *pqueue) popMin() *OrderNode { 69 | if len(q.prios) == 0 { 70 | return nil 71 | } 72 |
for i := 0; i < len(q.prios); i++ { 73 | switch { 74 | case len(q.prios[i]) > 0: 75 | return q.pop(i) 76 | default: 77 | continue 78 | } 79 | } 80 | return nil 81 | } 82 | 83 | func (q *pqueue) pop(i int) *OrderNode { 84 | prio := q.prios[i] 85 | o := prio[0] 86 | prio = prio[1:] 87 | q.prios[i] = prio 88 | return o 89 | } 90 | 91 | func (q *pqueue) cancel(guid uint64) *OrderNode { 92 | for i := range q.prios { 93 | priceQ := q.prios[i] 94 | for j := range priceQ { 95 | o := priceQ[j] 96 | if o.Guid() == guid { 97 | priceQ = append(priceQ[0:j], priceQ[j+1:]...) 98 | q.prios[i] = priceQ 99 | return o 100 | } 101 | } 102 | } 103 | return nil 104 | } 105 | -------------------------------------------------------------------------------- /matcher/pqueue/order.go: -------------------------------------------------------------------------------- 1 | package pqueue 2 | 3 | import ( 4 | "fmt" 5 | "github.com/fmstephe/flib/fmath" 6 | "github.com/fmstephe/flib/fstrconv" 7 | "github.com/fmstephe/matching_engine/msg" 8 | ) 9 | 10 | type OrderNode struct { 11 | priceNode node 12 | guidNode node 13 | amount uint64 14 | stockId uint64 15 | kind msg.MsgKind 16 | nextFree *OrderNode 17 | } 18 | 19 | func (o *OrderNode) CopyFrom(from *msg.Message) { 20 | o.amount = from.Amount 21 | o.stockId = from.StockId 22 | o.kind = from.Kind 23 | o.setup(from.Price, uint64(fmath.CombineInt32(int32(from.TraderId), int32(from.TradeId)))) 24 | } 25 | 26 | func (o *OrderNode) CopyTo(to *msg.Message) { 27 | to.Kind = o.Kind() 28 | to.Price = o.Price() 29 | to.Amount = o.Amount() 30 | to.TraderId = o.TraderId() 31 | to.TradeId = o.TradeId() 32 | to.StockId = o.StockId() 33 | } 34 | 35 | func (o *OrderNode) setup(price, guid uint64) { 36 | initNode(o, price, &o.priceNode, &o.guidNode) 37 | initNode(o, guid, &o.guidNode, &o.priceNode) 38 | } 39 | 40 | func (o *OrderNode) Price() uint64 { 41 | return o.priceNode.val 42 | } 43 | 44 | func (o *OrderNode) Guid() uint64 { 45 | return o.guidNode.val 46 | } 47
| 48 | func (o *OrderNode) TraderId() uint32 { 49 | return uint32(fmath.HighInt32(int64(o.guidNode.val))) 50 | } 51 | 52 | func (o *OrderNode) TradeId() uint32 { 53 | return uint32(fmath.LowInt32(int64(o.guidNode.val))) 54 | } 55 | 56 | func (o *OrderNode) Amount() uint64 { 57 | return o.amount 58 | } 59 | 60 | func (o *OrderNode) ReduceAmount(s uint64) { 61 | o.amount -= s 62 | } 63 | 64 | func (o *OrderNode) StockId() uint64 { 65 | return o.stockId 66 | } 67 | 68 | func (o *OrderNode) Kind() msg.MsgKind { 69 | return o.kind 70 | } 71 | 72 | func (o *OrderNode) Remove() { 73 | o.priceNode.pop() 74 | o.guidNode.pop() 75 | } 76 | 77 | func (o *OrderNode) String() string { 78 | if o == nil { 79 | return "" 80 | } 81 | price := fstrconv.ItoaDelim(int64(o.Price()), ',') 82 | amount := fstrconv.ItoaDelim(int64(o.Amount()), ',') 83 | traderId := fstrconv.ItoaDelim(int64(o.TraderId()), '-') 84 | tradeId := fstrconv.ItoaDelim(int64(o.TradeId()), '-') 85 | stockId := fstrconv.ItoaDelim(int64(o.StockId()), '-') 86 | kind := o.kind 87 | return fmt.Sprintf("%v, price %s, amount %s, trader %s, trade %s, stock %s", kind, price, amount, traderId, tradeId, stockId) 88 | } 89 | -------------------------------------------------------------------------------- /matcher/tcompare_test.go: -------------------------------------------------------------------------------- 1 | package matcher 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/coordinator" 5 | "github.com/fmstephe/matching_engine/msg" 6 | "testing" 7 | ) 8 | 9 | var cmprMaker = msg.NewMessageMaker(1) 10 | 11 | func TestCompareMatchers(t *testing.T) { 12 | compareMatchers(t, 100, 1, 1, 1) 13 | compareMatchers(t, 100, 10, 1, 1) 14 | // 15 | compareMatchers(t, 100, 1, 1, 2) 16 | compareMatchers(t, 100, 10, 1, 2) 17 | compareMatchers(t, 100, 100, 1, 2) 18 | // 19 | compareMatchers(t, 100, 1, 10, 20) 20 | compareMatchers(t, 100, 10, 10, 20) 21 | compareMatchers(t, 100, 100, 10, 20) 22 | // 23 | compareMatchers(t, 100, 1, 
100, 2000) 24 | compareMatchers(t, 100, 10, 100, 2000) 25 | compareMatchers(t, 100, 100, 100, 2000) 26 | } 27 | 28 | func compareMatchers(t *testing.T, orderPairs, depth int, lowPrice, highPrice uint64) { 29 | refIn := coordinator.NewChanReaderWriter(1) 30 | refOut := coordinator.NewChanReaderWriter(orderPairs * 4) 31 | refm := newRefmatcher(lowPrice, highPrice) 32 | refm.Config("Reference Matcher", refIn, refOut) 33 | in := coordinator.NewChanReaderWriter(1) 34 | out := coordinator.NewChanReaderWriter(orderPairs * 4) 35 | m := NewMatcher(orderPairs * 4) 36 | m.Config("Real Matcher", in, out) 37 | testSet, err := cmprMaker.RndTradeSet(orderPairs, depth, lowPrice, highPrice) 38 | if err != nil { 39 | panic(err.Error()) 40 | } 41 | go m.Run() 42 | go refm.Run() 43 | for i := 0; i < len(testSet); i++ { 44 | refIn.Write(testSet[i]) 45 | in.Write(testSet[i]) 46 | } 47 | refIn.Write(msg.Message{Kind: msg.SHUTDOWN}) 48 | in.Write(msg.Message{Kind: msg.SHUTDOWN}) 49 | checkBuffers(t, refOut, out) 50 | } 51 | 52 | func checkBuffers(t *testing.T, refrc, rc coordinator.MsgReader) { 53 | refrs := drain(refrc) 54 | rs := drain(rc) 55 | if len(refrs) != len(rs) { 56 | t.Errorf("Different number of writes detected. Simple: %d, Real: %d", len(refrs), len(rs)) 57 | return 58 | } 59 | for i := 0; i < len(rs); i++ { 60 | refr := refrs[i] 61 | r := rs[i] 62 | if *refr != *r { 63 | t.Errorf("Different responses read. 
Simple: %v, Real: %v", refr, r) 64 | return 65 | } 66 | } 67 | } 68 | 69 | func drain(r coordinator.MsgReader) []*msg.Message { 70 | ms := make([]*msg.Message, 0) 71 | for { 72 | m := &msg.Message{} 73 | *m = r.Read() 74 | ms = append(ms, m) 75 | if m.Kind == msg.SHUTDOWN { 76 | return ms 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /coordinator/trmsg_serialise_test.go: -------------------------------------------------------------------------------- 1 | package coordinator 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/msg" 5 | "testing" 6 | ) 7 | 8 | func messageBuffer() []byte { 9 | return make([]byte, rmsgByteSize) 10 | } 11 | 12 | func TestMarshallDoesNotDestroyMesssage(t *testing.T) { 13 | m := msg.Message{Kind: 7, Price: 8, Amount: 9, StockId: 10, TraderId: 11, TradeId: 12} 14 | ref := &RMessage{route: APP, direction: IN, originId: 5, msgId: 6, message: m} 15 | rm1 := &RMessage{} 16 | *rm1 = *ref 17 | b := messageBuffer() 18 | if err := rm1.Marshal(b); err != nil { 19 | t.Errorf("Unexpected marshalling error %s", err.Error()) 20 | } 21 | assertEquivalent(t, ref, rm1, b) 22 | } 23 | 24 | func TestMarshallUnMarshalPairsProducesSameRMessage(t *testing.T) { 25 | m1 := msg.Message{Kind: 7, Price: 8, Amount: 9, StockId: 10, TraderId: 11, TradeId: 12} 26 | rm1 := &RMessage{route: APP, direction: IN, originId: 5, msgId: 6, message: m1} 27 | b := messageBuffer() 28 | if err := rm1.Marshal(b); err != nil { 29 | t.Errorf("Unexpected marshalling error %s", err.Error()) 30 | } 31 | rm2 := &RMessage{} 32 | if err := rm2.Unmarshal(b); err != nil { 33 | t.Errorf("Unexpected unmarshalling error %s", err.Error()) 34 | } 35 | assertEquivalent(t, rm1, rm2, b) 36 | } 37 | 38 | func assertEquivalent(t *testing.T, exp, fnd *RMessage, b []byte) { 39 | if *fnd != *exp { 40 | t.Errorf("\nExpected to find %v\nfound %v\nMarshalled from %v", exp, fnd, b) 41 | } 42 | } 43 | 44 | func 
TestMarshalWithSmallBufferErrors(t *testing.T) { 45 | m1 := msg.Message{Kind: 3, Price: 4, Amount: 5, StockId: 6, TraderId: 7, TradeId: 8} 46 | rm1 := &RMessage{route: APP, direction: IN, originId: 1, msgId: 2, message: m1} 47 | b := make([]byte, rmsgByteSize-1) 48 | if err := rm1.Marshal(b); err == nil { 49 | t.Error("Expected marshalling error. Found none") 50 | } 51 | } 52 | 53 | func TestMarshalWithLargeBufferErrors(t *testing.T) { 54 | m1 := msg.Message{Kind: 3, Price: 4, Amount: 5, StockId: 6, TraderId: 7, TradeId: 8} 55 | rm1 := &RMessage{route: APP, direction: IN, originId: 1, msgId: 2, message: m1} 56 | b := make([]byte, rmsgByteSize+1) 57 | if err := rm1.Marshal(b); err == nil { 58 | t.Error("Expected marshalling error. Found none") 59 | } 60 | } 61 | 62 | func TestUnmarshalWithSmallBufferErrors(t *testing.T) { 63 | rm1 := &RMessage{} 64 | b := make([]byte, rmsgByteSize-1) 65 | if err := rm1.Unmarshal(b); err == nil { 66 | t.Error("Expected marshalling error. Found none") 67 | } 68 | } 69 | 70 | func TestUnmarshalWithLargeBufferErrors(t *testing.T) { 71 | rm1 := &RMessage{} 72 | b := make([]byte, rmsgByteSize+1) 73 | if err := rm1.Unmarshal(b); err == nil { 74 | t.Error("Expected marshalling error. 
Found none") 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /bin/svr/client/trader.go.disabled: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/msg" 5 | ) 6 | 7 | const ( 8 | connectedComment = "Connected to trader" 9 | ordersClosedComment = "Disconnected because orders channel was closed" 10 | replacedComment = "Disconnected because trader received a new connection" 11 | shutdownComment = "Disconnected because trader is shutting down" 12 | ) 13 | 14 | // Temporary constant while we are creating new traders when a connection is established 15 | const initialBalance = 100 16 | 17 | // Temporary function while we are creating new traders when a connection is established 18 | func initialStocks() map[uint64]uint64 { 19 | return map[uint64]uint64{1: 10, 2: 10, 3: 10} 20 | } 21 | 22 | type trader struct { 23 | balance *balanceManager 24 | // Communication with external system, e.g. 
a websocket connection 25 | orders chan *msg.Message 26 | responses chan *Response 27 | // Communication with internal trader server 28 | intoSvr chan *msg.Message 29 | outOfSvr chan *msg.Message 30 | connecter chan connect 31 | } 32 | 33 | func newTrader(traderId uint32, intoSvr, outOfSvr chan *msg.Message) (*trader, traderComm) { 34 | balance := newBalanceManager(traderId, initialBalance, initialStocks()) 35 | connecter := make(chan connect) 36 | t := &trader{balance: balance, intoSvr: intoSvr, outOfSvr: outOfSvr, connecter: connecter} 37 | tc := traderComm{outOfSvr: outOfSvr, connecter: connecter} 38 | return t, tc 39 | } 40 | 41 | func (t *trader) run() { 42 | defer t.shutdown() 43 | for { 44 | select { 45 | case con := <-t.connecter: 46 | t.connect(con) 47 | case m := <-t.orders: 48 | if m == nil { // channel has been closed 49 | t.disconnect(ordersClosedComment) 50 | continue 51 | } 52 | accepted := t.balance.process(m) 53 | t.sendResponse(m, accepted, "") 54 | if accepted { 55 | t.intoSvr <- m 56 | } 57 | case m := <-t.outOfSvr: 58 | accepted := t.balance.process(m) 59 | t.sendResponse(m, accepted, "") 60 | } 61 | } 62 | } 63 | 64 | // TODO currently trader never shuts down. How do we want to deal with this? 
65 | func (t *trader) shutdown() { 66 | t.disconnect(shutdownComment) 67 | } 68 | 69 | func (t *trader) connect(con connect) { 70 | t.disconnect(replacedComment) 71 | t.orders = con.orders 72 | t.responses = con.responses 73 | // Send a hello state message 74 | t.sendResponse(&msg.Message{}, true, connectedComment) 75 | } 76 | 77 | func (t *trader) disconnect(comment string) { 78 | if t.responses != nil { 79 | t.sendResponse(&msg.Message{}, true, comment) 80 | close(t.responses) 81 | } 82 | t.responses = nil 83 | t.orders = nil 84 | } 85 | 86 | func (t *trader) sendResponse(m *msg.Message, accepted bool, comment string) { 87 | if t.responses != nil { 88 | r := t.balance.makeResponse(m, accepted, comment) 89 | t.responses <- r 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /bin/svr/main.go.disabled: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "code.google.com/p/go.net/websocket" 5 | "encoding/json" 6 | "fmt" 7 | "github.com/fmstephe/matching_engine/bin/svr/client" 8 | "github.com/fmstephe/matching_engine/coordinator" 9 | "github.com/fmstephe/matching_engine/matcher" 10 | "github.com/fmstephe/matching_engine/msg" 11 | "github.com/fmstephe/matching_engine/q" 12 | "github.com/fmstephe/simpleid" 13 | "io" 14 | "net/http" 15 | "os" 16 | ) 17 | 18 | var traderMaker *client.TraderMaker 19 | var idMaker = simpleid.NewIdMaker() 20 | 21 | const ( 22 | clientOriginId = iota 23 | serverOriginId = iota 24 | ) 25 | 26 | func main() { 27 | pwd, err := os.Getwd() 28 | if err != nil { 29 | println(err.Error()) 30 | return 31 | } 32 | // Create matching engine + client 33 | clientToServer := q.NewSimpleQ("Client To Server") 34 | serverToClient := q.NewSimpleQ("Server To Client") 35 | // Matching Engine 36 | m := matcher.NewMatcher(100) 37 | var clientSvr *client.Server 38 | clientSvr, traderMaker = client.NewServer() 39 | coordinator.InMemory(serverToClient, 
clientToServer, clientSvr, clientOriginId, "Client.........", true) 40 | coordinator.InMemory(clientToServer, serverToClient, m, serverOriginId, "Matching Engine", true) 41 | http.Handle("/wsconn", websocket.Handler(handleTrader)) 42 | http.Handle("/", http.FileServer(http.Dir(pwd+"/html/"))) 43 | if err := http.ListenAndServe("127.0.0.1:8081", nil); err != nil { 44 | println(err.Error()) 45 | } 46 | } 47 | 48 | func handleTrader(ws *websocket.Conn) { 49 | traderId := uint32(idMaker.Id()) // NB: A fussy man would check that the id generated fitted inside uint32 50 | orders, responses := traderMaker.Make(traderId) 51 | go reader(ws, traderId, orders) 52 | writer(ws, traderId, responses) 53 | } 54 | 55 | func reader(ws *websocket.Conn, traderId uint32, orders chan<- *msg.Message) { 56 | defer close(orders) 57 | defer ws.Close() 58 | for { 59 | var data string 60 | if err := websocket.Message.Receive(ws, &data); err != nil { 61 | logError(traderId, err) 62 | return 63 | } 64 | m := &msg.Message{} 65 | if err := json.Unmarshal([]byte(data), m); err != nil { 66 | logError(traderId, err) 67 | return 68 | } 69 | m.TraderId = traderId 70 | println("WebSocket......: " + m.String()) 71 | orders <- m 72 | } 73 | } 74 | 75 | func writer(ws *websocket.Conn, traderId uint32, responses chan *client.Response) { 76 | defer ws.Close() 77 | for r := range responses { 78 | b, err := json.Marshal(r) 79 | if err != nil { 80 | logError(traderId, err) 81 | return 82 | } 83 | if _, err = ws.Write(b); err != nil { 84 | logError(traderId, err) 85 | return 86 | } 87 | } 88 | } 89 | 90 | func logError(traderId uint32, err error) { 91 | if err == io.EOF { 92 | println(fmt.Sprintf("Closing connection for trader %d", traderId)) 93 | } else { 94 | println(fmt.Sprintf("Error for trader %d: %s", traderId, err.Error())) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /matcher/tcoordinator_test.go: 
-------------------------------------------------------------------------------- 1 | package matcher 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/coordinator" 5 | "github.com/fmstephe/matching_engine/msg" 6 | "net" 7 | "runtime" 8 | "strconv" 9 | "testing" 10 | ) 11 | 12 | // Because we are communicating via UDP, messages could arrive out of order, in practice they travel in-order via localhost 13 | 14 | type netwkTesterMaker struct { 15 | ip [4]byte 16 | freePort int 17 | } 18 | 19 | func newMatchTesterMaker() MatchTesterMaker { 20 | return &netwkTesterMaker{ip: [4]byte{127, 0, 0, 1}, freePort: 1201} 21 | } 22 | 23 | func (tm *netwkTesterMaker) Make() MatchTester { 24 | serverPort := tm.freePort 25 | tm.freePort++ 26 | clientPort := tm.freePort 27 | tm.freePort++ 28 | // Build matcher 29 | m := NewMatcher(100) 30 | coordinator.InMemory(mkReadConn(serverPort), mkWriteConn(clientPort), m, 0, "Matching Engine", false) 31 | // Build client 32 | fromListener, toResponder := coordinator.InMemoryListenerResponder(mkReadConn(clientPort), mkWriteConn(serverPort), "Test Client ", false) 33 | return &netwkTester{receivedMsgs: fromListener, toSendMsgs: toResponder} 34 | } 35 | 36 | type netwkTester struct { 37 | receivedMsgs coordinator.MsgReader 38 | toSendMsgs coordinator.MsgWriter 39 | } 40 | 41 | func (nt *netwkTester) Send(t *testing.T, m *msg.Message) { 42 | nt.toSendMsgs.Write(*m) 43 | } 44 | 45 | func (nt *netwkTester) Expect(t *testing.T, e *msg.Message) { 46 | r := &msg.Message{} 47 | *r = nt.receivedMsgs.Read() 48 | validate(t, r, e, 2) 49 | } 50 | 51 | func (nt *netwkTester) ExpectOneOf(t *testing.T, es ...*msg.Message) { 52 | r := &msg.Message{} 53 | *r = nt.receivedMsgs.Read() 54 | for _, e := range es { 55 | if *e == *r { 56 | return 57 | } 58 | } 59 | t.Errorf("Expecting one of %v, received %v instead", es, r) 60 | } 61 | 62 | func (nt *netwkTester) Cleanup(t *testing.T) { 63 | nt.toSendMsgs.Write(msg.Message{Kind: msg.SHUTDOWN}) 64 | } 65 | 66 | 
func mkWriteConn(port int) *net.UDPConn { 67 | addr, err := net.ResolveUDPAddr("udp", ":"+strconv.Itoa(port)) 68 | if err != nil { 69 | panic(err) 70 | } 71 | conn, err := net.DialUDP("udp", nil, addr) 72 | if err != nil { 73 | panic(err) 74 | } 75 | return conn 76 | } 77 | 78 | func mkReadConn(port int) *net.UDPConn { 79 | addr, err := net.ResolveUDPAddr("udp", ":"+strconv.Itoa(port)) 80 | if err != nil { 81 | panic(err) 82 | } 83 | conn, err := net.ListenUDP("udp", addr) 84 | if err != nil { 85 | panic(err) 86 | } 87 | return conn 88 | } 89 | 90 | func validate(t *testing.T, m, e *msg.Message, stackOffset int) { 91 | if *m != *e { 92 | _, fname, lnum, _ := runtime.Caller(stackOffset) 93 | t.Errorf("\nExpecting: %v\nFound: %v \n%s:%d", e, m, fname, lnum) 94 | } 95 | } 96 | 97 | func TestRunCoordinatedTestSuite(t *testing.T) { 98 | RunTestSuite(t, newMatchTesterMaker()) 99 | } 100 | -------------------------------------------------------------------------------- /itch/reader.go: -------------------------------------------------------------------------------- 1 | package itch 2 | 3 | import ( 4 | "bufio" 5 | "github.com/fmstephe/matching_engine/msg" 6 | "io" 7 | "math" 8 | "os" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | type ItchReader struct { 14 | lineCount uint 15 | maxBuy uint64 16 | minSell uint64 17 | r *bufio.Reader 18 | } 19 | 20 | func NewItchReader(fName string) *ItchReader { 21 | f, err := os.Open(fName) 22 | if err != nil { 23 | panic(err.Error()) 24 | } 25 | r := bufio.NewReader(f) 26 | // Clear column headers 27 | if _, err := r.ReadString('\n'); err != nil { 28 | panic(err.Error()) 29 | } 30 | return &ItchReader{lineCount: 1, minSell: math.MaxInt32, r: r} 31 | } 32 | 33 | func (i *ItchReader) ReadMessage() (o *msg.Message, line string, err error) { 34 | i.lineCount++ 35 | for { 36 | line, err = i.r.ReadString('\n') 37 | if err != nil { 38 | return 39 | } 40 | if line != "" { 41 | break 42 | } 43 | } 44 | o, err = mkMessage(line) 45 | if o != nil && 
o.Kind == msg.BUY && o.Price > i.maxBuy { 46 | i.maxBuy = o.Price 47 | } 48 | if o != nil && o.Kind == msg.SELL && o.Price < i.minSell { 49 | i.minSell = o.Price 50 | } 51 | return 52 | } 53 | 54 | func (i *ItchReader) ReadAll() (orders []*msg.Message, err error) { 55 | orders = make([]*msg.Message, 0) 56 | var o *msg.Message 57 | for err == nil { 58 | o, _, err = i.ReadMessage() 59 | if o != nil { 60 | orders = append(orders, o) 61 | } 62 | } 63 | if err == io.EOF { 64 | err = nil 65 | } 66 | return 67 | } 68 | 69 | func (i *ItchReader) LineCount() uint { 70 | return i.lineCount 71 | } 72 | 73 | func (i *ItchReader) MaxBuy() uint64 { 74 | return i.maxBuy 75 | } 76 | 77 | func (i *ItchReader) MinSell() uint64 { 78 | return i.minSell 79 | } 80 | 81 | func mkMessage(line string) (o *msg.Message, err error) { 82 | ss := strings.Split(line, " ") 83 | var useful []string 84 | for _, w := range ss { 85 | if w != "" && w != "\n" { 86 | useful = append(useful, w) 87 | } 88 | } 89 | m, err := mkData(useful) 90 | *o = *m 91 | if err != nil { 92 | return 93 | } 94 | switch useful[3] { 95 | case "B": 96 | o.Kind = msg.BUY 97 | return 98 | case "S": 99 | o.Kind = msg.SELL 100 | return 101 | case "D": 102 | o.WriteCancelFor(o) 103 | return 104 | default: 105 | return 106 | } 107 | panic("Unreachable") 108 | } 109 | 110 | func mkData(useful []string) (m *msg.Message, err error) { 111 | // print("ID: ", useful[2], " Type: ", useful[3], " Price: ", useful[4], " Amount: ", useful[5]) 112 | // println() 113 | var price, amount, traderId, tradeId int 114 | amount, err = strconv.Atoi(useful[4]) 115 | price, err = strconv.Atoi(useful[5]) 116 | traderId, err = strconv.Atoi(useful[2]) 117 | tradeId, err = strconv.Atoi(useful[2]) 118 | if err != nil { 119 | return nil, err 120 | } 121 | return &msg.Message{Price: uint64(price), Amount: uint64(amount), TraderId: uint32(traderId), TradeId: uint32(tradeId), StockId: uint64(1)}, nil 122 | } 123 | 
-------------------------------------------------------------------------------- /coordinator/rmsg.go: -------------------------------------------------------------------------------- 1 | package coordinator 2 | 3 | import ( 4 | "fmt" 5 | "github.com/fmstephe/matching_engine/msg" 6 | ) 7 | 8 | // TODO this can probably be removed. If we log all reads but don't pass the malformed ones into the application then this is redundant. 9 | type MsgStatus byte 10 | 11 | const ( 12 | NORMAL = MsgStatus(iota) 13 | INVALID_MSG_ERROR = MsgStatus(iota) 14 | READ_ERROR = MsgStatus(iota) 15 | SMALL_READ_ERROR = MsgStatus(iota) 16 | WRITE_ERROR = MsgStatus(iota) 17 | SMALL_WRITE_ERROR = MsgStatus(iota) 18 | NUM_OF_STATUS = int32(iota) 19 | ) 20 | 21 | func (s MsgStatus) String() string { 22 | switch s { 23 | case NORMAL: 24 | return "NORMAL" 25 | case INVALID_MSG_ERROR: 26 | return "INVALID_MSG_ERROR" 27 | case READ_ERROR: 28 | return "READ_ERROR" 29 | case SMALL_READ_ERROR: 30 | return "SMALL_READ_ERROR" 31 | case WRITE_ERROR: 32 | return "WRITE_ERROR" 33 | case SMALL_WRITE_ERROR: 34 | return "SMALL_WRITE_ERROR" 35 | } 36 | panic("Bad Value") 37 | } 38 | 39 | type MsgDirection byte 40 | 41 | const ( 42 | NO_DIRECTION = MsgDirection(iota) 43 | OUT = MsgDirection(iota) 44 | IN = MsgDirection(iota) 45 | ) 46 | 47 | func (d MsgDirection) String() string { 48 | switch d { 49 | case NO_DIRECTION: 50 | return "NO_DIRECTION" 51 | case IN: 52 | return "IN" 53 | case OUT: 54 | return "OUT" 55 | } 56 | panic("Bad Value") 57 | } 58 | 59 | type MsgRoute byte 60 | 61 | const ( 62 | NO_ROUTE = MsgRoute(iota) 63 | APP = MsgRoute(iota) 64 | ACK = MsgRoute(iota) 65 | NUM_OF_ROUTE = int32(iota) 66 | ) 67 | 68 | func (r MsgRoute) String() string { 69 | switch r { 70 | case NO_ROUTE: 71 | return "NO_ROUTE" 72 | case APP: 73 | return "APP" 74 | case ACK: 75 | return "ACK" 76 | } 77 | panic("Bad Value") 78 | } 79 | 80 | // Flat description of an incoming message 81 | type RMessage struct { 82 | // Body 
83 | message msg.Message 84 | // Headers 85 | status MsgStatus 86 | direction MsgDirection 87 | route MsgRoute 88 | originId uint32 89 | msgId uint32 90 | } 91 | 92 | func (rm *RMessage) Valid() bool { 93 | // A message must always have a direction 94 | if rm.direction == NO_DIRECTION { 95 | return false 96 | } 97 | // Any message in an error status is valid 98 | if rm.status != NORMAL { 99 | return true 100 | } 101 | // Zero values for origin and Id are not valid 102 | if rm.originId == 0 || rm.msgId == 0 { 103 | return false 104 | } 105 | return rm.message.Valid() 106 | } 107 | 108 | func (rm *RMessage) WriteAckFor(orm *RMessage) { 109 | *rm = *orm 110 | rm.route = ACK 111 | rm.direction = OUT 112 | } 113 | 114 | func (rm *RMessage) String() string { 115 | if rm == nil { 116 | return "" 117 | } 118 | status := "" 119 | if rm.status != NORMAL { 120 | status = rm.status.String() + "! " 121 | } 122 | return fmt.Sprintf("(%s%v %v, %d, %d), %s", status, rm.direction, rm.route, rm.originId, rm.msgId, rm.message.String()) 123 | } 124 | -------------------------------------------------------------------------------- /msg/msg.go: -------------------------------------------------------------------------------- 1 | package msg 2 | 3 | import ( 4 | "fmt" 5 | "github.com/fmstephe/flib/fstrconv" 6 | "unsafe" 7 | ) 8 | 9 | type MsgKind uint64 10 | 11 | const ( 12 | NO_KIND = MsgKind(iota) 13 | BUY = MsgKind(iota) 14 | SELL = MsgKind(iota) 15 | CANCEL = MsgKind(iota) 16 | PARTIAL = MsgKind(iota) 17 | FULL = MsgKind(iota) 18 | CANCELLED = MsgKind(iota) 19 | NOT_CANCELLED = MsgKind(iota) 20 | REJECTED = MsgKind(iota) 21 | SHUTDOWN = MsgKind(iota) 22 | NEW_TRADER = MsgKind(iota) 23 | NUM_OF_KIND = int(iota) 24 | ) 25 | 26 | func (k MsgKind) String() string { 27 | switch k { 28 | case NO_KIND: 29 | return "NO_KIND" 30 | case BUY: 31 | return "BUY" 32 | case SELL: 33 | return "SELL" 34 | case CANCEL: 35 | return "CANCEL" 36 | case PARTIAL: 37 | return "PARTIAL" 38 | case FULL: 39 | 
return "FULL" 40 | case CANCELLED: 41 | return "CANCELLED" 42 | case NOT_CANCELLED: 43 | return "NOT_CANCELLED" 44 | case REJECTED: 45 | return "REJECTED" 46 | case SHUTDOWN: 47 | return "SHUTDOWN" 48 | case NEW_TRADER: 49 | return "NEW_TRADER" 50 | } 51 | panic("Uncreachable") 52 | } 53 | 54 | const ( 55 | // Constant price indicating a market price sell 56 | MARKET_PRICE = 0 57 | ) 58 | 59 | // Flat description of an incoming message 60 | type Message struct { 61 | Kind MsgKind `json:"kind"` 62 | Price uint64 `json:"price"` 63 | Amount uint64 `json:"amount"` 64 | StockId uint64 `json:"stockId"` 65 | TraderId uint32 `json:"traderId"` 66 | TradeId uint32 `json:"tradeId"` 67 | } 68 | 69 | const ( 70 | SizeofMessage = int(unsafe.Sizeof(Message{})) 71 | ) 72 | 73 | func (m *Message) Valid() bool { 74 | if m.Kind == SHUTDOWN { 75 | return m.Price == 0 && m.Amount == 0 && m.TraderId == 0 && m.TradeId == 0 && m.StockId == 0 76 | } 77 | if m.Kind == NEW_TRADER { 78 | return m.TraderId != 0 && m.Price == 0 && m.Amount == 0 && m.TradeId == 0 && m.StockId == 0 79 | } 80 | // Only sells (and messages cancelling sells) are allowed to have a price of 0 81 | isValid := (m.Price != 0 || m.Kind == SELL || m.Kind == CANCEL || m.Kind == CANCELLED || m.Kind == NOT_CANCELLED) 82 | // Remaining fields are never allowed to be 0 83 | isValid = isValid && m.Amount != 0 && m.TraderId != 0 && m.TradeId != 0 && m.StockId != 0 84 | // must have a kind 85 | isValid = isValid && m.Kind != NO_KIND 86 | return isValid 87 | } 88 | 89 | func (m *Message) WriteNewTrader(traderId uint32) { 90 | *m = Message{} 91 | m.Kind = NEW_TRADER 92 | m.TraderId = traderId 93 | } 94 | 95 | func (m *Message) WriteCancelFor(om *Message) { 96 | *m = *om 97 | m.Kind = CANCEL 98 | } 99 | 100 | func (m *Message) String() string { 101 | if m == nil { 102 | return "" 103 | } 104 | price := fstrconv.ItoaDelim(int64(m.Price), ',') 105 | amount := fstrconv.ItoaDelim(int64(m.Amount), ',') 106 | traderId := 
fstrconv.ItoaDelim(int64(m.TraderId), ' ') 107 | tradeId := fstrconv.ItoaDelim(int64(m.TradeId), ' ') 108 | stockId := fstrconv.ItoaDelim(int64(m.StockId), ' ') 109 | return fmt.Sprintf("%v, price %s, amount %s, trader %s, trade %s, stock %s", m.Kind, price, amount, traderId, tradeId, stockId) 110 | } 111 | -------------------------------------------------------------------------------- /bin/svr/client/server.go.disabled: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "github.com/fmstephe/matching_engine/coordinator" 6 | "github.com/fmstephe/matching_engine/msg" 7 | ) 8 | 9 | type connect struct { 10 | traderId uint32 11 | orders chan *msg.Message 12 | responses chan *Response 13 | } 14 | 15 | type traderComm struct { 16 | outOfSvr chan *msg.Message 17 | connecter chan connect 18 | } 19 | 20 | type Server struct { 21 | coordinator.AppMsgHelper 22 | intoSvr chan *msg.Message 23 | connecter chan connect 24 | traderMap map[uint32]traderComm 25 | connectsMap map[uint32][]connect 26 | } 27 | 28 | func NewServer() (*Server, *TraderMaker) { 29 | intoSvr := make(chan *msg.Message) 30 | traderMap := make(map[uint32]traderComm) 31 | connecter := make(chan connect) 32 | return &Server{intoSvr: intoSvr, traderMap: traderMap, connecter: connecter}, &TraderMaker{intoSvr: intoSvr, connecter: connecter} 33 | } 34 | 35 | func (s *Server) Run() { 36 | for { 37 | select { 38 | case m := <-s.In: 39 | s.fromServer(m) 40 | case m := <-s.intoSvr: 41 | s.fromTrader(m) 42 | case con := <-s.connecter: 43 | s.connectTrader(con) 44 | } 45 | } 46 | } 47 | 48 | func (s *Server) fromServer(m *msg.Message) { 49 | if m.Kind == msg.SHUTDOWN { 50 | return 51 | } 52 | if m != nil { 53 | cc, ok := s.traderMap[m.TraderId] 54 | if !ok { 55 | println("Missing traderId", m.TraderId) 56 | return 57 | } 58 | cc.outOfSvr <- m 59 | } 60 | } 61 | 62 | func (s *Server) fromTrader(m *msg.Message) { 63 | if m.Kind == 
msg.NEW_TRADER { 64 | s.newTrader(m) 65 | } else { 66 | s.Out <- m 67 | } 68 | } 69 | 70 | func (s *Server) newTrader(m *msg.Message) { 71 | _, ok := s.traderMap[m.TraderId] 72 | if ok { 73 | println(fmt.Sprintf("Attempted to register a trader (%i) twice", m.TraderId)) 74 | return 75 | } 76 | outOfSvr := make(chan *msg.Message) 77 | t, tc := newTrader(m.TraderId, s.intoSvr, outOfSvr) 78 | go t.run() 79 | s.traderMap[m.TraderId] = tc 80 | cons := s.connectsMap[m.TraderId] 81 | delete(s.connectsMap, m.TraderId) 82 | for _, con := range cons { 83 | tc.connecter <- con 84 | } 85 | } 86 | 87 | func (s *Server) connectTrader(con connect) { 88 | if cc, ok := s.traderMap[con.traderId]; ok { 89 | cc.connecter <- con 90 | } else { 91 | cons := s.connectsMap[con.traderId] 92 | if cons == nil { 93 | cons = make([]connect, 1) 94 | } 95 | cons = append(cons, con) 96 | s.connectsMap[con.traderId] = cons 97 | } 98 | } 99 | 100 | type TraderMaker struct { 101 | intoSvr chan *msg.Message 102 | connecter chan connect 103 | } 104 | 105 | func (tm *TraderMaker) Make(traderId uint32) (orders chan *msg.Message, responses chan *Response) { 106 | m := &msg.Message{} 107 | m.WriteNewTrader(traderId) 108 | tm.intoSvr <- m // Register this user 109 | return tm.Connect(traderId) 110 | } 111 | 112 | func (tm *TraderMaker) Connect(traderId uint32) (orders chan *msg.Message, responses chan *Response) { 113 | orders = make(chan *msg.Message) 114 | responses = make(chan *Response) 115 | con := connect{traderId: traderId, orders: orders, responses: responses} 116 | tm.connecter <- con 117 | return orders, responses 118 | } 119 | -------------------------------------------------------------------------------- /matcher/trefmatcher_test.go: -------------------------------------------------------------------------------- 1 | package matcher 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/coordinator" 5 | "github.com/fmstephe/matching_engine/matcher/pqueue" 6 | 
"github.com/fmstephe/matching_engine/msg" 7 | ) 8 | 9 | type refmatcher struct { 10 | matchQueues *pqueue.RefMatchQueues 11 | coordinator.AppMsgHelper 12 | } 13 | 14 | func newRefmatcher(lowPrice, highPrice uint64) *refmatcher { 15 | matchQueues := pqueue.NewRefMatchQueues(lowPrice, highPrice) 16 | return &refmatcher{matchQueues: matchQueues} 17 | } 18 | 19 | func (rm *refmatcher) Run() { 20 | m := &msg.Message{} 21 | for { 22 | *m = rm.In.Read() 23 | if m.Kind == msg.SHUTDOWN { 24 | rm.Out.Write(*m) 25 | return 26 | } 27 | if m != nil { 28 | o := &pqueue.OrderNode{} 29 | o.CopyFrom(m) 30 | if o.Kind() == msg.CANCEL { 31 | co := rm.matchQueues.Cancel(o) 32 | if co != nil { 33 | rm.completeCancelled(co) 34 | } 35 | if co == nil { 36 | rm.completeNotCancelled(o) 37 | } 38 | } else { 39 | rm.push(o) 40 | rm.match() 41 | } 42 | } 43 | } 44 | } 45 | 46 | func (rm *refmatcher) push(o *pqueue.OrderNode) { 47 | if o.Kind() == msg.BUY { 48 | rm.matchQueues.PushBuy(o) 49 | return 50 | } 51 | if o.Kind() == msg.SELL { 52 | rm.matchQueues.PushSell(o) 53 | return 54 | } 55 | panic("Unsupported trade kind pushed") 56 | } 57 | 58 | func (rm *refmatcher) match() { 59 | for { 60 | s := rm.matchQueues.PeekSell() 61 | b := rm.matchQueues.PeekBuy() 62 | if s == nil || b == nil { 63 | return 64 | } 65 | if s.Price() > b.Price() { 66 | return 67 | } 68 | if s.Amount() == b.Amount() { 69 | // pop both 70 | rm.matchQueues.PopSell() 71 | rm.matchQueues.PopBuy() 72 | amount := s.Amount() 73 | price := price(b.Price(), s.Price()) 74 | rm.completeTrade(msg.FULL, msg.FULL, b, s, price, amount) 75 | } 76 | if s.Amount() > b.Amount() { 77 | // pop buy 78 | rm.matchQueues.PopBuy() 79 | amount := b.Amount() 80 | price := price(b.Price(), s.Price()) 81 | s.ReduceAmount(b.Amount()) 82 | rm.completeTrade(msg.FULL, msg.PARTIAL, b, s, price, amount) 83 | } 84 | if b.Amount() > s.Amount() { 85 | // pop sell 86 | rm.matchQueues.PopSell() 87 | amount := s.Amount() 88 | price := price(b.Price(), s.Price()) 
89 | b.ReduceAmount(s.Amount()) 90 | rm.completeTrade(msg.PARTIAL, msg.FULL, b, s, price, amount) 91 | } 92 | } 93 | } 94 | 95 | func (rm *refmatcher) completeTrade(brk, srk msg.MsgKind, b, s *pqueue.OrderNode, price, amount uint64) { 96 | rm.Out.Write(msg.Message{Kind: brk, Price: price, Amount: amount, TraderId: b.TraderId(), TradeId: b.TradeId(), StockId: b.StockId()}) 97 | rm.Out.Write(msg.Message{Kind: srk, Price: price, Amount: amount, TraderId: s.TraderId(), TradeId: s.TradeId(), StockId: s.StockId()}) 98 | } 99 | 100 | func (rm *refmatcher) completeCancelled(c *pqueue.OrderNode) { 101 | cm := msg.Message{} 102 | c.CopyTo(&cm) 103 | cm.Kind = msg.CANCELLED 104 | rm.Out.Write(cm) 105 | } 106 | 107 | func (rm *refmatcher) completeNotCancelled(nc *pqueue.OrderNode) { 108 | ncm := msg.Message{} 109 | nc.CopyTo(&ncm) 110 | ncm.Kind = msg.NOT_CANCELLED 111 | rm.Out.Write(ncm) 112 | } 113 | -------------------------------------------------------------------------------- /coordinator/inmemory.go: -------------------------------------------------------------------------------- 1 | package coordinator 2 | 3 | import ( 4 | "fmt" 5 | "github.com/fmstephe/matching_engine/msg" 6 | "io" 7 | ) 8 | 9 | func InMemory(reader io.ReadCloser, writer io.WriteCloser, app AppMsgRunner, unused uint32, name string, log bool) { 10 | fromListener, toResponder := InMemoryListenerResponder(reader, writer, name, log) 11 | app.Config(name, fromListener, toResponder) 12 | go app.Run() 13 | } 14 | 15 | func InMemoryListenerResponder(reader io.ReadCloser, writer io.WriteCloser, name string, log bool) (MsgReader, MsgWriter) { 16 | fromListener := NewChanReaderWriter(1000) 17 | toResponder := NewChanReaderWriter(1000) 18 | listener := newInMemoryListener(reader, fromListener, name, log) 19 | responder := newInMemoryResponder(writer, toResponder, name, log) 20 | go listener.Run() 21 | go responder.Run() 22 | return fromListener, toResponder 23 | } 24 | 25 | type inMemoryListener struct { 26 | 
reader io.ReadCloser 27 | toApp MsgWriter 28 | name string 29 | log bool 30 | } 31 | 32 | func newInMemoryListener(reader io.ReadCloser, toApp MsgWriter, name string, log bool) *inMemoryListener { 33 | l := &inMemoryListener{} 34 | l.reader = reader 35 | l.toApp = toApp 36 | l.name = name 37 | l.log = log 38 | return l 39 | } 40 | 41 | func (l *inMemoryListener) Run() { 42 | defer l.shutdown() 43 | for { 44 | m := l.deserialise() 45 | shutdown := m.Kind == msg.SHUTDOWN 46 | l.toApp.Write(*m) 47 | if shutdown { 48 | return 49 | } 50 | } 51 | } 52 | 53 | func (l *inMemoryListener) deserialise() *msg.Message { 54 | b := make([]byte, msg.ByteSize) 55 | m := &msg.Message{} 56 | n, err := l.reader.Read(b) 57 | if err != nil { 58 | panic("Listener - UDP Read: " + err.Error()) 59 | } else if n != msg.ByteSize { 60 | panic(fmt.Sprintf("Listener: Error incorrect number of bytes. Expecting %d, found %d in %v", msg.ByteSize, n, b)) 61 | } 62 | if err := m.Unmarshal(b[:n]); err != nil { 63 | panic(err.Error()) 64 | } 65 | return m 66 | } 67 | 68 | func (l *inMemoryListener) shutdown() { 69 | l.reader.Close() 70 | } 71 | 72 | type inMemoryResponder struct { 73 | writer io.WriteCloser 74 | fromApp MsgReader 75 | name string 76 | log bool 77 | } 78 | 79 | func newInMemoryResponder(writer io.WriteCloser, fromApp MsgReader, name string, log bool) *inMemoryResponder { 80 | r := &inMemoryResponder{} 81 | r.writer = writer 82 | r.fromApp = fromApp 83 | r.name = name 84 | r.log = log 85 | return r 86 | } 87 | 88 | func (r *inMemoryResponder) Run() { 89 | defer r.shutdown() 90 | m := &msg.Message{} 91 | for { 92 | *m = r.fromApp.Read() 93 | if r.log { 94 | println(r.name + ": " + m.String()) 95 | } 96 | shutdown := m.Kind == msg.SHUTDOWN 97 | r.write(m) 98 | if shutdown { 99 | return 100 | } 101 | } 102 | } 103 | 104 | func (r *inMemoryResponder) write(m *msg.Message) { 105 | b := make([]byte, msg.ByteSize) 106 | if err := m.Marshal(b); err != nil { 107 | panic(err.Error()) 108 | } 109 | 
n, err := r.writer.Write(b) 110 | if err != nil { 111 | panic(err.Error()) 112 | } 113 | if n != msg.ByteSize { 114 | panic(fmt.Sprintf("Write Error: Wrong sized message. Found %d, expecting %d", n, msg.ByteSize)) 115 | } 116 | } 117 | 118 | func (r *inMemoryResponder) shutdown() { 119 | r.writer.Close() 120 | } 121 | -------------------------------------------------------------------------------- /coordinator/readwrite.go: -------------------------------------------------------------------------------- 1 | package coordinator 2 | 3 | import ( 4 | "github.com/fmstephe/flib/fmath" 5 | "github.com/fmstephe/flib/queues/spscq" 6 | "github.com/fmstephe/matching_engine/msg" 7 | "unsafe" 8 | ) 9 | 10 | //TODO need a testsuite for this 11 | 12 | // Is Thread-Safe in a single-reader/single-writer context 13 | type MsgReader interface { 14 | Read() msg.Message 15 | } 16 | 17 | // Is Thread-Safe in a single-reader/single-writer context 18 | type MsgWriter interface { 19 | Write(m msg.Message) 20 | } 21 | 22 | // Is Thread-Safe in a single-reader/single-writer context 23 | type MsgReaderWriter interface { 24 | MsgReader 25 | MsgWriter 26 | } 27 | 28 | // A MsgReader/MsgWriter implementation using channels 29 | type ChanReaderWriter struct { 30 | inout chan msg.Message 31 | } 32 | 33 | func NewChanReaderWriter(size int) *ChanReaderWriter { 34 | inout := make(chan msg.Message, size) 35 | return &ChanReaderWriter{ 36 | inout: inout, 37 | } 38 | } 39 | 40 | func (rw *ChanReaderWriter) Read() msg.Message { 41 | return <-rw.inout 42 | } 43 | 44 | func (rw *ChanReaderWriter) Write(m msg.Message) { 45 | rw.inout <- m 46 | } 47 | 48 | // A MsgReader/MsgWriter implementation using spscq.PointerQ 49 | // Should be fast 50 | type SPSCQReaderWriter struct { 51 | q *spscq.PointerQ 52 | } 53 | 54 | func NewSPSCQReaderWriter(size int64) *SPSCQReaderWriter { 55 | p2Size := fmath.NxtPowerOfTwo(size) 56 | q, err := spscq.NewPointerQ(p2Size, 1024) 57 | if err != nil { 58 | panic(err) 59 | } 60 | 
return &SPSCQReaderWriter{ 61 | q: q, 62 | } 63 | } 64 | 65 | func (rw *SPSCQReaderWriter) Read() msg.Message { 66 | return *((*msg.Message)(rw.q.ReadSingleBlocking())) 67 | } 68 | 69 | func (rw *SPSCQReaderWriter) Write(m msg.Message) { 70 | rw.q.WriteSingleBlocking(unsafe.Pointer(&m)) 71 | } 72 | 73 | func (rw *SPSCQReaderWriter) Fails() (int64, int64) { 74 | return rw.q.FailedReads(), rw.q.FailedWrites() 75 | } 76 | 77 | // A Reader loaded with a set of test *msg.Message 78 | // When the messages slice is exhausted Read() returns 79 | // SHUTDOWN messages 80 | type PreloadedReaderWriter struct { 81 | idx int 82 | ms []msg.Message 83 | } 84 | 85 | func NewPreloadedReaderWriter(ms []msg.Message) *PreloadedReaderWriter { 86 | return &PreloadedReaderWriter{ 87 | ms: ms, 88 | } 89 | } 90 | 91 | func (r *PreloadedReaderWriter) Read() msg.Message { 92 | if r.idx >= len(r.ms) { 93 | return msg.Message{Kind: msg.SHUTDOWN} 94 | } else { 95 | m := r.ms[r.idx] 96 | r.idx++ 97 | return m 98 | } 99 | } 100 | 101 | func (r *PreloadedReaderWriter) Write(m msg.Message) { 102 | } 103 | 104 | // A Writer which does almost nothing. Good for some performance testing. 105 | type ShutdownReaderWriter struct { 106 | out chan msg.Message 107 | } 108 | 109 | func NewShutdownReaderWriter() *ShutdownReaderWriter { 110 | return &ShutdownReaderWriter{ 111 | out: make(chan msg.Message, 1), 112 | } 113 | } 114 | 115 | func (rw *ShutdownReaderWriter) Read() msg.Message { 116 | return <-rw.out 117 | } 118 | 119 | func (rw *ShutdownReaderWriter) Write(m msg.Message) { 120 | if m.Kind == msg.SHUTDOWN { 121 | rw.out <- m 122 | } 123 | } 124 | 125 | // A Writer which does absolutely nothing. Good for single threaded performance testing. 
126 | type NoopReaderWriter struct { 127 | } 128 | 129 | func NewNoopReaderWriter() *NoopReaderWriter { 130 | return &NoopReaderWriter{} 131 | } 132 | 133 | func (rw *NoopReaderWriter) Read() msg.Message { 134 | return msg.Message{} 135 | } 136 | 137 | func (rw *NoopReaderWriter) Write(m msg.Message) { 138 | } 139 | -------------------------------------------------------------------------------- /msg/maker.go: -------------------------------------------------------------------------------- 1 | package msg 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math/rand" 7 | ) 8 | 9 | type MessageMaker struct { 10 | traderId uint32 11 | r *rand.Rand 12 | } 13 | 14 | func NewMessageMaker(initTraderId uint32) *MessageMaker { 15 | r := rand.New(rand.NewSource(1)) 16 | return &MessageMaker{traderId: initTraderId, r: r} 17 | } 18 | 19 | func (mm *MessageMaker) Seed(seed int64) { 20 | mm.r.Seed(seed) 21 | } 22 | 23 | func (mm *MessageMaker) Between(lower, upper uint64) uint64 { 24 | if lower > upper { 25 | panic(fmt.Sprintf("lower must be less than upper. lower was %d, upper was %d", lower, upper)) 26 | } 27 | if int64(lower) < 0 || int64(upper) < 0 { 28 | panic(fmt.Sprintf("lower and higher must be <= 2^63. 
lower was %d, upper was %d", lower, upper)) 29 | } 30 | if lower == upper { 31 | return lower 32 | } 33 | low, up := int64(lower), int64(upper) 34 | return uint64(mm.r.Int63n(up-low) + low) 35 | } 36 | 37 | func (mm *MessageMaker) MkPricedOrder(price uint64, kind MsgKind) *Message { 38 | m := &Message{} 39 | mm.writePricedOrder(price, kind, m) 40 | return m 41 | } 42 | 43 | func (mm *MessageMaker) writePricedOrder(price uint64, kind MsgKind, m *Message) { 44 | mm.traderId++ 45 | *m = Message{Price: price, Amount: 1, TraderId: mm.traderId, TradeId: 1, StockId: 1} 46 | m.Kind = kind 47 | } 48 | 49 | func (mm *MessageMaker) ValRangePyramid(n int, low, high uint64) []uint64 { 50 | seq := (high - low) / 4 51 | vals := make([]uint64, n) 52 | for i := 0; i < n; i++ { 53 | val := mm.Between(0, seq) + mm.Between(0, seq) + mm.Between(0, seq) + mm.Between(0, seq) 54 | vals[i] = uint64(val) + low 55 | } 56 | return vals 57 | } 58 | 59 | func (mm *MessageMaker) ValRangeFlat(n int, low, high uint64) []uint64 { 60 | vals := make([]uint64, n) 61 | for i := 0; i < n; i++ { 62 | vals[i] = mm.Between(low, high) 63 | } 64 | return vals 65 | } 66 | 67 | func (mm *MessageMaker) MkBuys(prices []uint64, stockId uint64) []Message { 68 | return mm.MkOrders(prices, stockId, BUY) 69 | } 70 | 71 | func (mm *MessageMaker) MkSells(prices []uint64, stockId uint64) []Message { 72 | return mm.MkOrders(prices, stockId, SELL) 73 | } 74 | 75 | func (mm *MessageMaker) MkOrders(prices []uint64, stockId uint64, kind MsgKind) []Message { 76 | msgs := make([]Message, len(prices)) 77 | for i, price := range prices { 78 | mm.traderId++ 79 | msgs[i] = Message{Price: price, Amount: 1, TraderId: mm.traderId, TradeId: uint32(i + 1), StockId: stockId} 80 | msgs[i].Kind = kind 81 | } 82 | return msgs 83 | } 84 | 85 | func (mm *MessageMaker) RndTradeSet(size, depth int, low, high uint64) ([]Message, error) { 86 | if depth > size { 87 | return nil, errors.New(fmt.Sprintf("Size (%d) must be greater than or equal to 
(%d)", size, depth)) 88 | } 89 | orders := make([]Message, size*4) 90 | buys := make([]*Message, 0, size) 91 | sells := make([]*Message, 0, size) 92 | idx := 0 93 | for i := 0; i < size+depth; i++ { 94 | if i < size { 95 | b := &orders[idx] 96 | idx++ 97 | mm.writePricedOrder(mm.Between(low, high), BUY, b) 98 | buys = append(buys, b) 99 | if b.Price == 0 { 100 | b.Price = 1 // Buys can't have price of 0 101 | } 102 | s := &orders[idx] 103 | idx++ 104 | mm.writePricedOrder(mm.Between(low, high), SELL, s) 105 | sells = append(sells, s) 106 | } 107 | if i >= depth { 108 | b := buys[i-depth] 109 | cb := &orders[idx] 110 | idx++ 111 | cb.WriteCancelFor(b) 112 | s := sells[i-depth] 113 | cs := &orders[idx] 114 | idx++ 115 | cs.WriteCancelFor(s) 116 | } 117 | } 118 | return orders, nil 119 | } 120 | -------------------------------------------------------------------------------- /bin/itchdebug/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func main() { 4 | println("itchdebug not supported at this time.") 5 | } 6 | 7 | /* 8 | 9 | import ( 10 | "bufio" 11 | "bytes" 12 | "flag" 13 | "fmt" 14 | "github.com/fmstephe/fstrconv" 15 | "github.com/fmstephe/matching_engine/matcher" 16 | "github.com/fmstephe/matching_engine/msg" 17 | "github.com/fmstephe/matching_engine/itch" 18 | "os" 19 | ) 20 | 21 | var ( 22 | filePath = flag.String("f", "", "Relative path to an ITCH file to read") 23 | mode = flag.String("m", "step", "Running mode. Currently supporting 'step', 'exec' and 'list'") 24 | line = flag.Uint("l", 0, "First line to break on. 
Mode is ignored until line l is reached, then normal excution continues") 25 | ) 26 | 27 | func main() { 28 | flag.Parse() 29 | loop() 30 | } 31 | 32 | func loop() { 33 | l := *line 34 | for { 35 | ir := itch.NewItchReader(*filePath) 36 | defer func() { 37 | if r := recover(); r != nil { 38 | println(fmt.Sprintf("Panic at line %d", ir.LineCount())) 39 | print(fmt.Sprintf("%#v", r)) 40 | panic("Repanic") 41 | } 42 | }() 43 | in := bufio.NewReader(os.Stdin) 44 | dispatch := make(chan interface{}, 20) 45 | orders := make(chan *msg.Message) 46 | m := matcher.NewMatcher(1000) 47 | m.SetDispatch(dispatch) 48 | m.SetAppMsgs(appMsgs) 49 | go m.Run() 50 | // 51 | var o *msg.OrderNode 52 | var err error 53 | for { 54 | o, _, err = ir.ReadOrderNode() 55 | if err != nil { 56 | panic(err) 57 | } 58 | if o != nil && (o.Kind() == msg.BUY || o.Kind() == msg.SELL || o.Kind() == msg.CANCEL) { 59 | orders<-o 60 | clear(dispatch) 61 | } 62 | checkPrint(ir, o, m, l) 63 | c := checkPause(in, ir, o, l) 64 | if c == 'k' { 65 | l = ir.LineCount() - 1 66 | break 67 | } 68 | } 69 | } 70 | } 71 | 72 | func checkPause(in *bufio.Reader, ir *ItchReader, o *msg.OrderNode, bLine uint) byte { 73 | if bLine > ir.LineCount() { 74 | return 'z' 75 | } 76 | if *mode == "step" { 77 | return pause(in) 78 | } 79 | if *mode == "exec" && o == nil { 80 | return pause(in) 81 | } 82 | return 'z' 83 | } 84 | 85 | func pause(in *bufio.Reader) byte { 86 | c, err := in.ReadByte() 87 | if err != nil { 88 | println(err.Error()) 89 | os.Exit(1) 90 | } 91 | return c 92 | } 93 | 94 | func checkPrint(ir *ItchReader, o *msg.OrderNode, m *matcher.M, bLine uint) { 95 | if bLine > ir.LineCount() { 96 | return 97 | } 98 | if *mode == "step" { 99 | printInfo(ir, o, m) 100 | } 101 | if *mode == "exec" && o == nil { 102 | printInfo(ir, o, m) 103 | } 104 | } 105 | 106 | func printInfo(ir *ItchReader, o *msg.OrderNode, m *matcher.M) { 107 | buys, sells, orders, executions := m.Survey() 108 | println("OrderNode ", o.String()) 109 | 
println("Line ", ir.LineCount()) 110 | println("Max Buy ", fstrconv.Itoa64Delim(int64(ir.MaxBuy()), ',')) 111 | println("Min Sell ", fstrconv.Itoa64Delim(int64(ir.MinSell()), ',')) 112 | println("Executions ", executions) 113 | println("...") 114 | println("Total ", orders.Size()) 115 | println("Buy Limits ", formatLimits(buys)) 116 | println("Sell Limits ", formatLimits(sells)) 117 | println() 118 | } 119 | 120 | func formatLimits(limits []*msg.SurveyLimit) string { 121 | var b bytes.Buffer 122 | for _, l := range limits { 123 | b.WriteString(fmt.Sprintf("(%s, %s)", fstrconv.Itoa64Delim(int64(l.Price), ','), fstrconv.Itoa64Delim(int64(l.Size), ','))) 124 | b.WriteString(", ") 125 | } 126 | return b.String() 127 | } 128 | 129 | func drain(c chan interface{}) { 130 | for len(c) > 0 { 131 | <-c 132 | } 133 | } 134 | */ 135 | -------------------------------------------------------------------------------- /bin/perf/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "github.com/fmstephe/flib/fstrconv" 6 | "github.com/fmstephe/matching_engine/coordinator" 7 | "github.com/fmstephe/matching_engine/matcher" 8 | "github.com/fmstephe/matching_engine/msg" 9 | "log" 10 | "math/rand" 11 | "os" 12 | "runtime/pprof" 13 | "time" 14 | ) 15 | 16 | const ( 17 | StockId = uint32(1) 18 | ) 19 | 20 | var ( 21 | filePath = flag.String("f", "", "Relative path to an ITCH file providing test data") 22 | profile = flag.String("p", "", "Write out a profile of this application, 'cpu' and 'mem' supported") 23 | orderNum = flag.Int("o", 1, "The number of orders to generate (in millions). 
Ignored if -f is provided") 24 | delDelay = flag.Int("d", 10, "The number of orders generated before we begin deleting existing orders") 25 | perfRand = rand.New(rand.NewSource(1)) 26 | orderMaker = msg.NewMessageMaker(1) 27 | ) 28 | 29 | func main() { 30 | doPerf(true) 31 | } 32 | 33 | func doPerf(log bool) { 34 | flag.Parse() 35 | data := getData() 36 | if log { 37 | orderCount := fstrconv.ItoaComma(int64(len(data))) 38 | println(orderCount, "OrderNodes Built") 39 | } 40 | start := time.Now() 41 | defer func() { 42 | if log { 43 | println("Running Time: ", time.Now().Sub(start).String()) 44 | } 45 | }() 46 | startProfile() 47 | defer endProfile() 48 | singleThreaded(log, data) 49 | } 50 | 51 | func singleThreaded(log bool, data []msg.Message) { 52 | inout := coordinator.NewNoopReaderWriter() 53 | mchr := matcher.NewMatcher(*delDelay * 2) 54 | mchr.Config("Perf Matcher", inout, inout) 55 | for i := range data { 56 | mchr.Submit(&data[i]) 57 | } 58 | } 59 | 60 | func multiThreadedChan(log bool, data []msg.Message) { 61 | in := coordinator.NewChanReaderWriter(1024) 62 | out := coordinator.NewChanReaderWriter(1024) 63 | multiThreaded(log, data, in, out) 64 | } 65 | 66 | func multiThreadedPreloaded(log bool, data []msg.Message) { 67 | in := coordinator.NewPreloadedReaderWriter(data) 68 | out := coordinator.NewShutdownReaderWriter() 69 | multiThreaded(log, data, in, out) 70 | } 71 | 72 | func multiThreadedSPSCQ(log bool, data []msg.Message) { 73 | in := coordinator.NewSPSCQReaderWriter(1024 * 1024) 74 | out := coordinator.NewSPSCQReaderWriter(1024 * 1024) 75 | multiThreaded(log, data, in, out) 76 | } 77 | 78 | func multiThreaded(log bool, data []msg.Message, in, out coordinator.MsgReaderWriter) { 79 | mchr := matcher.NewMatcher(*delDelay * 2) 80 | mchr.Config("Perf Matcher", in, out) 81 | go run(mchr) 82 | go write(in, data) 83 | // Read all messages coming out of the matching engine 84 | read(out) 85 | } 86 | 87 | func read(reader coordinator.MsgReader) { 88 | for { 
89 | m := reader.Read() 90 | if m.Kind == msg.SHUTDOWN { 91 | break 92 | } 93 | } 94 | } 95 | 96 | type runner interface { 97 | Run() 98 | } 99 | 100 | func run(r runner) { 101 | r.Run() 102 | } 103 | 104 | func write(in coordinator.MsgWriter, msgs []msg.Message) { 105 | for i := range msgs { 106 | in.Write(msgs[i]) 107 | } 108 | in.Write(msg.Message{Kind: msg.SHUTDOWN}) 109 | } 110 | 111 | func startProfile() { 112 | if *profile == "cpu" { 113 | f, err := os.Create("cpu.prof") 114 | if err != nil { 115 | log.Fatal(err) 116 | } 117 | pprof.StartCPUProfile(f) 118 | } 119 | } 120 | 121 | func endProfile() { 122 | if *profile == "cpu" { 123 | pprof.StopCPUProfile() 124 | } 125 | if *profile == "mem" { 126 | f, err := os.Create("mem.prof") 127 | if err != nil { 128 | log.Fatal(err) 129 | } 130 | pprof.WriteHeapProfile(f) 131 | } 132 | } 133 | 134 | func getData() []msg.Message { 135 | orders, err := orderMaker.RndTradeSet(*orderNum*1000*1000, *delDelay, 1000, 1500) 136 | if err != nil { 137 | panic(err.Error()) 138 | } 139 | return orders 140 | } 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # matching_engine 2 | =============== 3 | 4 | A simple financial trading matching engine. Built to learn more about how they work. 5 | 6 | This system is not a complete implementation of a full matching engine system. It contains a basic set of internal components for a matching engine but doesn't include any networking components which would be required to actually use it in production. 7 | 8 | ## msg 9 | 10 | The `msg` package defines the basic messages which are processed by a matching engine. 
The system was designed to be _entirely_ driven by these messages, so we find both the expected buy/sell limit-order type messages as well as new-trader and shutdown messages which change the state of the matching engine without actually engaging in trading activity. 11 | 12 | The rationale behind this choice was that the state of a matching engine should be determined _exclusively_ by the series of messages which have passed through it. This allows us to perfectly reproduce any matching engine state simply by replaying these messages. 13 | 14 | There is some initial work done on a binary serialisation format. This was intended to be used for storing messages in logs, and for sending over the network. I don't think the approach taken here was very well conceived, and I would probably use an efficient encoding library if I were to work on this now. 15 | 16 | The `maker.go` file is interesting in that it generates random sets of buy/sell messages. While the code-base has many traditional run-code/assert-outcome style unit tests, we also employed a lot of system-invariant style tests with randomly generated test data. 17 | 18 | ## matcher/pqueue 19 | 20 | This is a priority queue implementation custom built to support the matching engine. This is definitely the most complex piece of code in the system. The pqueue is a pair of linked red-black trees. The two trees share the nodes. Specifically, an `OrderNode` is a member of two red-black trees: one tree is ordered by price and the other is ordered by the id of the order. A feature of having a single order object linked to two trees is that when we remove an `OrderNode` from one tree we can remove it from the other, without having to perform a search through the tree to find it. 21 | 22 | The red-black tree is an internal detail; the publicly exposed type is `pqueue.MatchQueues`. This is a trio of red-black trees: one for buys, one for sells, and one for all orders sorted by guid.
This exposes the operations needed to submit a buy or sell order, as well as to cancel an existing order (using its guid). 23 | 24 | The testing approach used here is primarily invariant-based testing. We perform operations on the rb-trees and then test to ensure that the structural invariants of the trees have not been violated. In the development of these trees I first started with a body of hand-written operations, followed by assertions on the result. Even after writing a very large number of unit tests I was not able to find any bugs. However, when I built the randomly generated invariant-based testing I was able to find a number of subtle bugs and fix them. 25 | 26 | If I did this again I would keep the 'normal' style unit tests as well as the more complex invariant style tests. Although invariant testing was more valuable in finding bugs, 'normal' style unit tests are easy to read and are a useful form of documentation of the expected behaviour of the rb-trees. 27 | 28 | ## matcher 29 | 30 | The matcher implements an actual matching engine. It uses a `pqueue.MatchQueues` (one per stock id) to manage incoming orders. As each new order comes in, an attempt is made to match the order, buy or sell, and the resulting matches are written to the output. Cancelling orders is supported, as is shutting down the order book. 31 | 32 | ## coordinator 33 | 34 | This package is designed to allow us to wrap a `matcher.M` with an input and output queue. There are two main implementations available: one which uses a Go channel and one which uses an imported high-performance queue. The imported queue is from another project I authored, which can be found at `github.com/fmstephe/flib`. 35 | 36 | I would not use this approach if I were building this system again today. I think that the choice to make the `matcher.M` struct embed the `coordinator.AppMsgHelper` struct is unnecessarily complicated.
37 | -------------------------------------------------------------------------------- /bin/svr/html/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 122 | 123 | 124 | 125 | 126 |
127 | Price:
128 | Amount:
129 | StockId:
130 |
131 | 132 | 133 |

134 |

135 |

136 |

137 |

138 |

139 | 140 | 141 | 142 | -------------------------------------------------------------------------------- /msg/tmsg_test.go: -------------------------------------------------------------------------------- 1 | package msg 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | var LOCALHOST = [4]byte{127, 0, 0, 1} 8 | 9 | var ( 10 | // A full message with every field set 11 | fullMessage = Message{Price: 1, Amount: 1, TraderId: 1, TradeId: 1, StockId: 1} 12 | // A full message but with the price set to 0, i.e. a sell that matches any buy price 13 | openSellMessage = Message{Price: 0, Amount: 1, TraderId: 1, TradeId: 1, StockId: 1} 14 | // Collection of messages with missing fields (skipping price) 15 | partialBodyMessages = []Message{ 16 | {Price: 1, Amount: 0, TraderId: 1, TradeId: 1, StockId: 1}, 17 | {Price: 1, Amount: 1, TraderId: 0, TradeId: 1, StockId: 1}, 18 | {Price: 1, Amount: 1, TraderId: 1, TradeId: 0, StockId: 1}, 19 | {Price: 1, Amount: 1, TraderId: 1, TradeId: 1, StockId: 0}, 20 | } 21 | // no fields set at all 22 | blankMessage = Message{} 23 | ) 24 | 25 | func testFullAndOpenSell(t *testing.T, f func(Message) Message, full, openSell bool) { 26 | testAllCategories(t, f, full, openSell, false, false) 27 | } 28 | 29 | func testAllCategories(t *testing.T, f func(Message) Message, full, openSell, partialBody, blank bool) { 30 | m := f(fullMessage) 31 | expect(t, full, m) 32 | m = f(openSellMessage) 33 | expect(t, openSell, m) 34 | for _, p := range partialBodyMessages { 35 | m = f(p) 36 | expect(t, partialBody, m) 37 | } 38 | m = f(blankMessage) 39 | expect(t, blank, m) 40 | } 41 | 42 | func expect(t *testing.T, isValid bool, m Message) { 43 | if isValid != m.Valid() { 44 | if isValid { 45 | t.Errorf("\nExpected valid\n%v", &m) 46 | } else { 47 | t.Errorf("\nExpected invalid\n%v", &m) 48 | } 49 | } 50 | } 51 | 52 | func TestKindlessMessages(t *testing.T) { 53 | f := func(m Message) Message { 54 | m.Kind = NO_KIND 55 | return m 56 | } 57 | testFullAndOpenSell(t, f,
false, false) 58 | } 59 | 60 | func TestZeroedMessage(t *testing.T) { 61 | expect(t, false, Message{}) 62 | } 63 | 64 | func TestWriteBuy(t *testing.T) { 65 | f := func(m Message) Message { 66 | m.Kind = BUY 67 | return m 68 | } 69 | testFullAndOpenSell(t, f, true, false) 70 | } 71 | 72 | func TestWriteSell(t *testing.T) { 73 | f := func(m Message) Message { 74 | m.Kind = SELL 75 | return m 76 | } 77 | testFullAndOpenSell(t, f, true, true) 78 | } 79 | 80 | func TestWriteCancelFor(t *testing.T) { 81 | f := func(m Message) Message { 82 | cm := Message{} 83 | cm.WriteCancelFor(&m) 84 | return cm 85 | } 86 | testFullAndOpenSell(t, f, true, true) 87 | } 88 | 89 | func TestWriteApp(t *testing.T) { 90 | // can't write no_kind app 91 | f := func(m Message) Message { 92 | m.Kind = NO_KIND 93 | return m 94 | } 95 | testFullAndOpenSell(t, f, false, false) 96 | // can write buy app 97 | f = func(m Message) Message { 98 | m.Kind = BUY 99 | return m 100 | } 101 | testFullAndOpenSell(t, f, true, false) 102 | // can write sell app 103 | f = func(m Message) Message { 104 | m.Kind = SELL 105 | return m 106 | } 107 | testFullAndOpenSell(t, f, true, true) 108 | // can write cancel app 109 | f = func(m Message) Message { 110 | m.Kind = CANCEL 111 | return m 112 | } 113 | testFullAndOpenSell(t, f, true, true) 114 | // can write partial app 115 | f = func(m Message) Message { 116 | m.Kind = PARTIAL 117 | return m 118 | } 119 | testFullAndOpenSell(t, f, true, false) 120 | // can write full app 121 | f = func(m Message) Message { 122 | m.Kind = FULL 123 | return m 124 | } 125 | testFullAndOpenSell(t, f, true, false) 126 | // can write cancelled app 127 | f = func(m Message) Message { 128 | m.Kind = CANCELLED 129 | return m 130 | } 131 | testFullAndOpenSell(t, f, true, true) 132 | // can write not_cancelled app 133 | f = func(m Message) Message { 134 | m.Kind = NOT_CANCELLED 135 | return m 136 | } 137 | testFullAndOpenSell(t, f, true, true) 138 | } 139 | 140 | func TestWriteCancelled(t
*testing.T) { 141 | // Can cancel standard message and open sell 142 | f := func(m Message) Message { 143 | m.Kind = CANCELLED 144 | return m 145 | } 146 | testFullAndOpenSell(t, f, true, true) 147 | } 148 | 149 | func TestWriteNotCancelled(t *testing.T) { 150 | // Can not_cancel standard message and open sell 151 | f := func(m Message) Message { 152 | m.Kind = NOT_CANCELLED 153 | return m 154 | } 155 | testFullAndOpenSell(t, f, true, true) 156 | } 157 | -------------------------------------------------------------------------------- /matcher/matcher.go: -------------------------------------------------------------------------------- 1 | package matcher 2 | 3 | import ( 4 | "fmt" 5 | "github.com/fmstephe/matching_engine/coordinator" 6 | "github.com/fmstephe/matching_engine/matcher/pqueue" 7 | "github.com/fmstephe/matching_engine/msg" 8 | ) 9 | 10 | type M struct { 11 | coordinator.AppMsgHelper 12 | matchQueues map[uint64]*pqueue.MatchQueues 13 | slab *pqueue.Slab 14 | } 15 | 16 | func NewMatcher(slabSize int) *M { 17 | matchQueues := make(map[uint64]*pqueue.MatchQueues) 18 | slab := pqueue.NewSlab(slabSize) 19 | return &M{matchQueues: matchQueues, slab: slab} 20 | } 21 | 22 | func (m *M) Run() { 23 | o := &msg.Message{} 24 | for { 25 | *o = m.In.Read() 26 | if o.Kind == msg.SHUTDOWN { 27 | m.Out.Write(*o) 28 | return 29 | } 30 | m.Submit(o) 31 | } 32 | } 33 | 34 | func (m *M) Submit(o *msg.Message) { 35 | on := m.slab.Malloc() 36 | on.CopyFrom(o) 37 | switch on.Kind() { 38 | case msg.BUY: 39 | m.addBuy(on) 40 | case msg.SELL: 41 | m.addSell(on) 42 | case msg.CANCEL: 43 | m.cancel(on) 44 | default: 45 | panic(fmt.Sprintf("MsgKind %v not supported", on)) 46 | } 47 | } 48 | 49 | func (m *M) getMatchQueues(stockId uint64) *pqueue.MatchQueues { 50 | q := m.matchQueues[stockId] 51 | if q == nil { 52 | q = &pqueue.MatchQueues{} 53 | m.matchQueues[stockId] = q 54 | } 55 | return q 56 | } 57 | 58 | func (m *M) addBuy(b *pqueue.OrderNode) { 59 | if b.Price() ==
msg.MARKET_PRICE { 60 | panic("It is illegal to send a buy at market price") 61 | } 62 | q := m.getMatchQueues(b.StockId()) 63 | if !m.fillableBuy(b, q) { 64 | q.PushBuy(b) 65 | } 66 | } 67 | 68 | func (m *M) addSell(s *pqueue.OrderNode) { 69 | q := m.getMatchQueues(s.StockId()) 70 | if !m.fillableSell(s, q) { 71 | q.PushSell(s) 72 | } 73 | } 74 | 75 | func (m *M) cancel(o *pqueue.OrderNode) { 76 | q := m.getMatchQueues(o.StockId()) 77 | ro := q.Cancel(o) 78 | if ro != nil { 79 | m.completeCancelled(ro) 80 | m.slab.Free(ro) 81 | } else { 82 | m.completeNotCancelled(o) 83 | } 84 | m.slab.Free(o) 85 | } 86 | 87 | func (m *M) fillableBuy(b *pqueue.OrderNode, q *pqueue.MatchQueues) bool { 88 | for { 89 | s := q.PeekSell() 90 | if s == nil { 91 | return false 92 | } 93 | if b.Price() >= s.Price() { 94 | if b.Amount() > s.Amount() { 95 | amount := s.Amount() 96 | price := price(b.Price(), s.Price()) 97 | b.ReduceAmount(amount) 98 | m.completeTrade(msg.PARTIAL, msg.FULL, b, s, price, amount) 99 | s.Remove() 100 | m.slab.Free(s) // Free s only after completeTrade has read its ids, as every other match branch does 101 | continue // The sell has been used up 102 | } 103 | if s.Amount() > b.Amount() { 104 | amount := b.Amount() 105 | price := price(b.Price(), s.Price()) 106 | s.ReduceAmount(amount) 107 | m.completeTrade(msg.FULL, msg.PARTIAL, b, s, price, amount) 108 | m.slab.Free(b) 109 | return true // The buy has been used up 110 | } 111 | if s.Amount() == b.Amount() { 112 | amount := b.Amount() 113 | price := price(b.Price(), s.Price()) 114 | m.completeTrade(msg.FULL, msg.FULL, b, s, price, amount) 115 | s.Remove() 116 | m.slab.Free(s) 117 | m.slab.Free(b) 118 | return true // The buy and sell have been used up 119 | } 120 | } else { 121 | return false 122 | } 123 | } 124 | } 125 | 126 | func (m *M) fillableSell(s *pqueue.OrderNode, q *pqueue.MatchQueues) bool { 127 | for { 128 | b := q.PeekBuy() 129 | if b == nil { 130 | return false 131 | } 132 | if b.Price() >= s.Price() { 133 | if b.Amount() > s.Amount() { 134 | amount := s.Amount() 135 | price :=
price(b.Price(), s.Price()) 136 | b.ReduceAmount(amount) 137 | m.completeTrade(msg.PARTIAL, msg.FULL, b, s, price, amount) 138 | s.Remove() // NOTE(review): s is the incoming sell and was never pushed to q (fillableBuy frees its incoming b without Remove) - confirm Remove is safe on an unqueued node 139 | m.slab.Free(s) 140 | return true // The sell has been used up 141 | } 142 | if s.Amount() > b.Amount() { 143 | amount := b.Amount() 144 | price := price(b.Price(), s.Price()) 145 | s.ReduceAmount(amount) 146 | m.completeTrade(msg.FULL, msg.PARTIAL, b, s, price, amount) 147 | b.Remove() 148 | m.slab.Free(b) // The buy has been used up 149 | continue 150 | } 151 | if s.Amount() == b.Amount() { 152 | amount := b.Amount() 153 | price := price(b.Price(), s.Price()) 154 | m.completeTrade(msg.FULL, msg.FULL, b, s, price, amount) 155 | b.Remove() 156 | m.slab.Free(b) 157 | m.slab.Free(s) 158 | return true // The sell and buy have been used up 159 | } 160 | } else { 161 | return false 162 | } 163 | } 164 | } 165 | 166 | func price(bPrice, sPrice uint64) uint64 { 167 | if sPrice == msg.MARKET_PRICE { 168 | return bPrice 169 | } 170 | d := bPrice - sPrice 171 | return sPrice + (d / 2) 172 | } 173 | 174 | func (m *M) completeTrade(brk, srk msg.MsgKind, b, s *pqueue.OrderNode, price, amount uint64) { 175 | m.Out.Write(msg.Message{Kind: brk, Price: price, Amount: amount, TraderId: b.TraderId(), TradeId: b.TradeId(), StockId: b.StockId()}) 176 | m.Out.Write(msg.Message{Kind: srk, Price: price, Amount: amount, TraderId: s.TraderId(), TradeId: s.TradeId(), StockId: s.StockId()}) 177 | } 178 | 179 | func (m *M) completeCancelled(c *pqueue.OrderNode) { 180 | cm := msg.Message{} 181 | c.CopyTo(&cm) 182 | cm.Kind = msg.CANCELLED 183 | m.Out.Write(cm) 184 | } 185 | 186 | func (m *M) completeNotCancelled(nc *pqueue.OrderNode) { 187 | ncm := msg.Message{} 188 | nc.CopyTo(&ncm) 189 | ncm.Kind = msg.NOT_CANCELLED 190 | m.Out.Write(ncm) 191 | } 192 | -------------------------------------------------------------------------------- /bin/svr/client/balanceManager.go.disabled:
-------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/msg" 5 | "strconv" 6 | ) 7 | 8 | type balanceManager struct { 9 | traderId uint32 10 | curTradeId uint32 11 | outstanding []msg.Message 12 | current uint64 13 | available uint64 14 | held map[uint64]uint64 15 | toSell map[uint64]uint64 16 | } 17 | 18 | func newBalanceManager(traderId uint32, balance uint64, initialStocks map[uint64]uint64) *balanceManager { 19 | bm := &balanceManager{current: balance, available: balance} 20 | bm.traderId = traderId 21 | bm.curTradeId = 0 22 | bm.held = make(map[uint64]uint64) 23 | bm.toSell = make(map[uint64]uint64) 24 | for k, v := range initialStocks { 25 | bm.held[k] = v 26 | } 27 | return bm 28 | } 29 | 30 | // NB: After this method returns BUYs and SELLs are guaranteed to have the correct TradeId 31 | // BUYs, SELLs and CANCELs are guaranteed to have the correct TraderId 32 | // CANCELLEDs and FULLs are assumed to have the correct values and are unchanged 33 | func (bm *balanceManager) process(m *msg.Message) bool { 34 | switch m.Kind { 35 | case msg.CANCEL: 36 | m.TraderId = bm.traderId 37 | return bm.processCancel(m) 38 | case msg.BUY: 39 | m.TraderId = bm.traderId 40 | bm.curTradeId++ 41 | m.TradeId = bm.curTradeId 42 | return bm.processBuy(m) 43 | case msg.SELL: 44 | m.TraderId = bm.traderId 45 | bm.curTradeId++ 46 | m.TradeId = bm.curTradeId 47 | return bm.processSell(m) 48 | case msg.CANCELLED: 49 | return bm.cancelOutstanding(m) 50 | case msg.FULL, msg.PARTIAL: 51 | return bm.matchOutstanding(m) 52 | } 53 | return false // NOTE(review): NOT_CANCELLED is not handled, so a CANCEL entry appended by processCancel for an already-filled order appears to stay in outstanding - confirm 54 | } 55 | 56 | func (bm *balanceManager) processCancel(m *msg.Message) bool { 57 | bm.outstanding = append(bm.outstanding, *m) 58 | return true 59 | } 60 | 61 | func (bm *balanceManager) processBuy(m *msg.Message) bool { 62 | if bm.available < (m.Price * m.Amount) { 63 | return false 64 | } 65 | bm.available -= m.Price * m.Amount 66 |
bm.outstanding = append(bm.outstanding, *m) 67 | return true 68 | } 69 | 70 | func (bm *balanceManager) processSell(m *msg.Message) bool { 71 | if bm.held[m.StockId] < m.Amount { 72 | return false 73 | } 74 | bm.addHeld(m.StockId, -m.Amount) 75 | bm.addToSell(m.StockId, m.Amount) 76 | bm.outstanding = append(bm.outstanding, *m) 77 | // Don't clean up, we want the zeroed held stocks to remain 78 | return true 79 | } 80 | 81 | func (bm *balanceManager) cancelOutstanding(m *msg.Message) bool { 82 | accepted := false 83 | newOutstanding := make([]msg.Message, 0, len(bm.outstanding)) 84 | for _, om := range bm.outstanding { 85 | if om.TradeId != m.TradeId { 86 | newOutstanding = append(newOutstanding, om) 87 | } else { 88 | accepted = true 89 | switch om.Kind { 90 | case msg.BUY: 91 | bm.cancelBuy(m.Price, m.Amount) 92 | case msg.SELL: 93 | bm.cancelSell(m.StockId, m.Amount) 94 | } 95 | } 96 | } 97 | bm.outstanding = newOutstanding 98 | return accepted 99 | } 100 | 101 | func (bm *balanceManager) matchOutstanding(m *msg.Message) bool { 102 | accepted := false 103 | newOutstanding := make([]msg.Message, 0, len(bm.outstanding)) 104 | for _, om := range bm.outstanding { 105 | if om.TradeId != m.TradeId { 106 | newOutstanding = append(newOutstanding, om) 107 | } else { 108 | accepted = true 109 | if m.Kind == msg.PARTIAL { 110 | om.Amount -= m.Amount // Reduce this copy before keeping it, rather than indexing newOutstanding with a position from bm.outstanding 111 | newOutstanding = append(newOutstanding, om) 112 | } 113 | if om.Kind == msg.SELL { 114 | bm.completeSell(m.StockId, m.Price, m.Amount) 115 | } 116 | if om.Kind == msg.BUY { 117 | bm.completeBuy(m.StockId, om.Price, m.Price, m.Amount) 118 | } 119 | } 120 | } 121 | bm.outstanding = newOutstanding 122 | return accepted 123 | } 124 | 125 | func (bm *balanceManager) cleanup(stockId uint64) { 126 | if bm.toSell[stockId] == 0 { 127 | delete(bm.toSell, stockId) 128 | } 129 | if bm.held[stockId] == 0 { 130 | delete(bm.held, stockId) 131 | } 132 | } 133 | 134 | func (bm *balanceManager) addHeld(stockId, amount
uint64) { 135 | held := bm.held[stockId] 136 | bm.held[stockId] = held + amount 137 | } 138 | 139 | func (bm *balanceManager) addToSell(stockId, amount uint64) { 140 | toSell := bm.toSell[stockId] 141 | bm.toSell[stockId] = toSell + amount 142 | } 143 | 144 | func (bm *balanceManager) cancelBuy(price, amount uint64) { 145 | bm.available += price * amount 146 | } 147 | 148 | func (bm *balanceManager) cancelSell(stockId, amount uint64) { 149 | bm.addHeld(stockId, amount) 150 | bm.addToSell(stockId, -amount) 151 | bm.cleanup(stockId) 152 | } 153 | 154 | func (bm *balanceManager) completeBuy(stockId, bidPrice, actualPrice, amount uint64) { 155 | bidTotal := bidPrice * amount 156 | actualTotal := actualPrice * amount 157 | bm.available += bidTotal 158 | bm.available -= actualTotal 159 | bm.current -= actualTotal 160 | bm.addHeld(stockId, amount) 161 | } 162 | 163 | func (bm *balanceManager) completeSell(stockId, price, amount uint64) { 164 | total := price * amount 165 | bm.current += total 166 | bm.available += total 167 | bm.addToSell(stockId, -amount) 168 | bm.cleanup(stockId) 169 | } 170 | 171 | func (bm *balanceManager) makeResponse(m *msg.Message, accepted bool, comment string) *Response { 172 | rm := receivedMessage{Message: *m, Accepted: accepted} 173 | current := bm.current 174 | available := bm.available 175 | held := mapToResponse(bm.held) 176 | toSell := mapToResponse(bm.toSell) 177 | os := make([]msg.Message, len(bm.outstanding)) 178 | copy(os, bm.outstanding) 179 | s := traderState{CurrentBalance: current, AvailableBalance: available, StocksHeld: held, StocksToSell: toSell, Outstanding: os} 180 | return &Response{State: s, Received: rm, Comment: comment} 181 | } 182 | 183 | func mapToResponse(in map[uint64]uint64) map[string]uint64 { 184 | out := make(map[string]uint64) 185 | for k, v := range in { 186 | ks := strconv.FormatUint(k, 10) 187 | out[ks] = v 188 | } 189 | return out 190 | } 191 | -------------------------------------------------------------------------------- /matcher/pqueue/rbtree.go:
-------------------------------------------------------------------------------- /matcher/pqueue/rbtree.go: -------------------------------------------------------------------------------- 1 | package pqueue 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "strconv" 8 | ) 9 | 10 | type rbtree struct { 11 | root *node 12 | } 13 | 14 | func (b *rbtree) String() string { 15 | return b.root.String() 16 | } 17 | 18 | func (b *rbtree) push(in *node) { 19 | if b.root == nil { 20 | b.root = in 21 | in.pp = &b.root 22 | return 23 | } 24 | b.root.push(in) 25 | } 26 | 27 | func (b *rbtree) peekMin() *node { 28 | n := b.root 29 | if n == nil { 30 | return nil 31 | } 32 | for n.left != nil { 33 | n = n.left 34 | } 35 | return n 36 | } 37 | 38 | func (b *rbtree) popMin() *node { 39 | if b.root != nil { 40 | n := b.peekMin() 41 | n.pop() 42 | n.other.pop() // Clear complementary rbtree 43 | return n 44 | } 45 | return nil 46 | } 47 | 48 | func (b *rbtree) peekMax() *node { 49 | n := b.root 50 | if n == nil { 51 | return nil 52 | } 53 | for n.right != nil { 54 | n = n.right 55 | } 56 | return n 57 | } 58 | 59 | func (b *rbtree) popMax() *node { 60 | if b.root != nil { 61 | n := b.peekMax() 62 | n.pop() 63 | n.other.pop() // Clear complementary rbtree 64 | return n 65 | } 66 | return nil 67 | } 68 | 69 | func (b *rbtree) cancel(val uint64) *node { 70 | n := b.get(val) 71 | if n == nil { 72 | return nil 73 | } 74 | n.pop() 75 | n.other.pop() 76 | return n 77 | } 78 | 79 | func (b *rbtree) Has(val uint64) bool { 80 | return b.get(val) != nil 81 | } 82 | 83 | func (b *rbtree) get(val uint64) *node { 84 | n := b.root 85 | for { 86 | if n == nil { 87 | return nil 88 | } 89 | if val == n.val { 90 | return n 91 | } 92 | if val < n.val { 93 | n = n.left 94 | } else { 95 | n = n.right 96 | } 97 | } 98 | panic("Unreachable") 99 | } 100 | 101 | type node struct { 102 | black bool 103 | // Tree fields 104 | val uint64 105 | left *node 106 | right *node 107 | parent *node 108 | pp **node 
109 | // Limit queue fields 110 | next *node 111 | prev *node 112 | // OrderNode 113 | order *OrderNode 114 | // This is the other node tying order to another rbtree 115 | other *node 116 | } 117 | 118 | func (n *node) String() string { 119 | if n == nil { 120 | return "()" 121 | } 122 | valStr := strconv.Itoa(int(n.val)) 123 | colour := "R" 124 | if n.black { 125 | colour = "B" 126 | } 127 | b := bytes.NewBufferString("") 128 | b.WriteString("(") 129 | b.WriteString(valStr) 130 | b.WriteString(colour) 131 | if !(n.left == nil && n.right == nil) { 132 | b.WriteString(", ") 133 | b.WriteString(n.left.String()) 134 | b.WriteString(", ") 135 | b.WriteString(n.right.String()) 136 | } 137 | b.WriteString(")") 138 | return b.String() 139 | } 140 | 141 | func initNode(o *OrderNode, val uint64, n, other *node) { 142 | *n = node{val: val, order: o, other: other} 143 | n.next = n 144 | n.prev = n 145 | n.black = false 146 | } 147 | 148 | func (n *node) getOrderNode() *OrderNode { 149 | if n != nil { 150 | return n.order 151 | } 152 | return nil 153 | } 154 | 155 | func (n *node) isRed() bool { 156 | if n != nil { 157 | return !n.black 158 | } 159 | return false 160 | } 161 | 162 | func (n *node) isFree() bool { 163 | switch { 164 | case n.left != nil: 165 | return false 166 | case n.right != nil: 167 | return false 168 | case n.pp != nil: 169 | return false 170 | case n.next != n: 171 | return false 172 | case n.prev != n: 173 | return false 174 | } 175 | return true 176 | } 177 | 178 | func (n *node) isHead() bool { 179 | return n.pp != nil 180 | } 181 | 182 | func (n *node) getSibling() *node { 183 | p := n.parent 184 | if p == nil { 185 | return nil 186 | } 187 | if p.left == n { 188 | return p.right 189 | } 190 | return p.left 191 | } 192 | 193 | func (n *node) addLast(in *node) { 194 | last := n.next 195 | last.prev = in 196 | in.next = last 197 | in.prev = n 198 | n.next = in 199 | } 200 | 201 | func (n *node) giveParent(nn *node) { 202 | nn.parent = n.parent 203 | 
nn.pp = n.pp 204 | *nn.pp = nn 205 | n.parent = nil 206 | n.pp = nil 207 | } 208 | 209 | func (n *node) giveChildren(nn *node) { 210 | nn.left = n.left 211 | nn.right = n.right 212 | if nn.left != nil { 213 | nn.left.parent = nn 214 | nn.left.pp = &nn.left 215 | } 216 | if nn.right != nil { 217 | nn.right.parent = nn 218 | nn.right.pp = &nn.right 219 | } 220 | n.left = nil 221 | n.right = nil 222 | } 223 | 224 | func (n *node) givePosition(nn *node) { 225 | n.giveParent(nn) 226 | n.giveChildren(nn) 227 | nn.black = n.black 228 | // Guarantee: Each of n.parent/pp/left/right are now nil 229 | } 230 | 231 | func (n *node) push(in *node) { 232 | for { 233 | switch { 234 | case in.val == n.val: 235 | n.addLast(in) 236 | return 237 | case in.val < n.val: 238 | if n.left == nil { 239 | in.toLeftOf(n) 240 | repairInsert(n) 241 | return 242 | } else { 243 | n = n.left 244 | } 245 | case in.val > n.val: 246 | if n.right == nil { 247 | in.toRightOf(n) 248 | repairInsert(n) 249 | return 250 | } else { 251 | n = n.right 252 | } 253 | } 254 | } 255 | } 256 | 257 | func (n *node) detach() { 258 | p := n.parent 259 | s := n.getSibling() 260 | var nn *node 261 | switch { 262 | case n.right == nil && n.left == nil: 263 | *n.pp = nil 264 | n.pp = nil 265 | n.parent = nil 266 | case n.right == nil: 267 | nn = n.left 268 | n.giveParent(nn) 269 | n.left = nil 270 | case n.left == nil: 271 | nn = n.right 272 | n.giveParent(nn) 273 | n.right = nil 274 | default: 275 | nn = n.left.detachMax() 276 | n.givePosition(nn) 277 | return 278 | } 279 | repairDetach(p, n, s, nn) 280 | } 281 | 282 | func repairDetach(p, n, s, nn *node) { 283 | // Guarantee: Each of n.parent/pp/left/right are now nil 284 | if n.isRed() { 285 | return 286 | } 287 | if nn.isRed() { 288 | // Since n was black we can happily make its red replacement black 289 | nn.black = true 290 | return 291 | } 292 | repairToRoot(p, s) 293 | } 294 | 295 | func repairToRoot(p, s *node) { 296 | for p != nil { 297 | if s == nil { 298 | 
return 299 | } 300 | if s.isRed() { // Perform a rotation to make sibling black 301 | if p.left == s { 302 | p.rotateRight() 303 | s = p.left 304 | } else { 305 | p.rotateLeft() 306 | s = p.right 307 | } 308 | } 309 | pRed := p.isRed() 310 | slRed := s.left.isRed() 311 | srRed := s.right.isRed() 312 | if !slRed && !srRed { 313 | if pRed { // Sibling's children are black and parent is red 314 | p.black = true 315 | s.black = false 316 | return 317 | } else { // Sibling's children and parent are black, makes a black violation 318 | s.black = false 319 | } 320 | } else { // One of sibling's children is red 321 | if p.left == s { 322 | if slRed { 323 | p = p.rotateRight() 324 | } else { 325 | s.rotateLeft() 326 | p = p.rotateRight() 327 | } 328 | } else { 329 | if srRed { 330 | p = p.rotateLeft() 331 | } else { 332 | s.rotateRight() 333 | p = p.rotateLeft() 334 | } 335 | } 336 | p.black = !pRed 337 | p.left.black = true 338 | p.right.black = true 339 | return 340 | } 341 | s = p.getSibling() 342 | p = p.parent 343 | } 344 | } 345 | 346 | func repairInsert(n *node) { 347 | for n != nil { 348 | if n.left.isRed() && n.right.isRed() { 349 | n.flip() 350 | } 351 | if n.left.isRed() { 352 | if n.left.left.isRed() { 353 | n = n.rotateRight() 354 | } 355 | if n.left.right.isRed() { 356 | n.left.rotateLeft() 357 | n = n.rotateRight() 358 | } 359 | } 360 | if n.right.isRed() { 361 | if n.right.right.isRed() { 362 | n = n.rotateLeft() 363 | } 364 | if n.right.left.isRed() { 365 | n.right.rotateRight() 366 | n = n.rotateLeft() 367 | } 368 | } 369 | n = n.parent 370 | } 371 | } 372 | 373 | func (n *node) pop() { 374 | switch { 375 | case !n.isHead(): 376 | n.prev.next = n.next 377 | n.next.prev = n.prev 378 | n.parent = nil 379 | n.pp = nil 380 | n.left = nil 381 | n.right = nil 382 | case n.next != n: 383 | n.prev.next = n.next 384 | n.next.prev = n.prev 385 | nn := n.prev 386 | n.givePosition(nn) 387 | default: 388 | n.detach() 389 | } 390 | n.next = n 391 | n.prev = n 392 | // 
Guarantee: Each of n.parent/pp/left/right are now nil 393 | // Guarantee: Both n.left/right point to n 394 | } 395 | 396 | func (n *node) detachMax() *node { 397 | m := n 398 | for { 399 | if m.right == nil { 400 | break 401 | } 402 | m = m.right 403 | } 404 | m.detach() 405 | return m 406 | } 407 | 408 | func (n *node) toRightOf(to *node) { 409 | to.right = n 410 | if n != nil { 411 | n.parent = to 412 | n.pp = &to.right 413 | } 414 | } 415 | 416 | func (n *node) toLeftOf(to *node) { 417 | to.left = n 418 | if n != nil { 419 | n.parent = to 420 | n.pp = &to.left 421 | } 422 | } 423 | 424 | func (n *node) rotateLeft() *node { 425 | r := n.right 426 | n.giveParent(r) 427 | r.left.toRightOf(n) 428 | n.toLeftOf(r) 429 | r.black = n.black 430 | n.black = false 431 | return r 432 | } 433 | 434 | func (n *node) rotateRight() *node { 435 | l := n.left 436 | n.giveParent(l) 437 | l.right.toLeftOf(n) 438 | n.toRightOf(l) 439 | l.black = n.black 440 | n.black = false 441 | return l 442 | } 443 | 444 | func (n *node) flip() { 445 | n.black = !n.black 446 | n.left.black = !n.left.black 447 | n.right.black = !n.right.black 448 | } 449 | 450 | func (n *node) moveRedLeft() { 451 | n.flip() 452 | if n.right.left.isRed() { 453 | n.right.rotateRight() 454 | n.rotateLeft() 455 | n.parent.flip() 456 | } 457 | } 458 | 459 | func (n *node) moveRedRight() { 460 | n.flip() 461 | if n.left.left.isRed() { 462 | n.rotateRight() 463 | n.parent.flip() 464 | } 465 | } 466 | 467 | func validateRBT(rbt *rbtree) (err error) { 468 | defer func() { 469 | if r := recover(); r != nil { 470 | err = r.(error) 471 | } 472 | }() 473 | blackBalance(rbt.root, 0) 474 | testReds(rbt.root, 0) 475 | return nil 476 | } 477 | 478 | func blackBalance(n *node, depth int) int { 479 | if n == nil { 480 | return 0 481 | } 482 | lb := blackBalance(n.left, depth+1) 483 | rb := blackBalance(n.right, depth+1) 484 | if lb != rb { 485 | panic(errors.New(fmt.Sprintf("Unbalanced rbtree found at depth %d. 
Left: , %d Right: %d", depth, lb, rb))) 486 | } 487 | b := lb 488 | if !n.isRed() { 489 | b++ 490 | } 491 | return b 492 | } 493 | 494 | func testReds(n *node, depth int) { 495 | if n == nil { 496 | return 497 | } 498 | if n.isRed() && (n.left.isRed() || n.right.isRed()) && depth != 0 { 499 | panic(errors.New(fmt.Sprintf("Red violation found at depth %d", depth))) 500 | } 501 | testReds(n.left, depth+1) 502 | testReds(n.right, depth+1) 503 | } 504 | -------------------------------------------------------------------------------- /matcher/pqueue/tprioq_test.go: -------------------------------------------------------------------------------- 1 | package pqueue 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/msg" 5 | "math/rand" 6 | "testing" 7 | ) 8 | 9 | // A function signature allowing us to switch easily between min and max queues 10 | type popperFun func(*testing.T, *rbtree, *rbtree, *pqueue) (*OrderNode, *OrderNode, *OrderNode) 11 | 12 | var msgMkr = msg.NewMessageMaker(1) 13 | 14 | func TestPush(t *testing.T) { 15 | // buys 16 | testPushAscDesc(t, 100, msg.BUY) 17 | // buys 18 | testPushSimple(t, 1, 1, 1, msg.BUY) 19 | testPushSimple(t, 4, 1, 1, msg.SELL) 20 | testPushSimple(t, 100, 10, 20, msg.BUY) 21 | testPushSimple(t, 100, 100, 10000, msg.SELL) 22 | testPushSimple(t, 1000, 100, 10000, msg.BUY) 23 | } 24 | 25 | func TestPushPopSimpleMin(t *testing.T) { 26 | // buys 27 | testPushPopSimple(t, 1, 1, 1, msg.BUY, maxPopper) 28 | testPushPopSimple(t, 4, 1, 1, msg.BUY, maxPopper) 29 | testPushPopSimple(t, 100, 10, 20, msg.BUY, maxPopper) 30 | testPushPopSimple(t, 100, 100, 10000, msg.BUY, maxPopper) 31 | testPushPopSimple(t, 1000, 100, 10000, msg.BUY, maxPopper) 32 | // sells 33 | testPushPopSimple(t, 1, 1, 1, msg.SELL, minPopper) 34 | testPushPopSimple(t, 100, 1, 1, msg.SELL, minPopper) 35 | testPushPopSimple(t, 100, 10, 20, msg.SELL, minPopper) 36 | testPushPopSimple(t, 100, 100, 10000, msg.SELL, minPopper) 37 | testPushPopSimple(t, 1000, 100, 10000, 
msg.SELL, minPopper) 38 | } 39 | 40 | func TestRandomPushPop(t *testing.T) { 41 | // buys 42 | testPushPopRandom(t, 1, 1, 1, msg.BUY, maxPopper) 43 | testPushPopRandom(t, 100, 1, 1, msg.BUY, maxPopper) 44 | testPushPopRandom(t, 100, 10, 20, msg.BUY, maxPopper) 45 | testPushPopRandom(t, 100, 100, 10000, msg.BUY, maxPopper) 46 | testPushPopRandom(t, 1000, 100, 10000, msg.BUY, maxPopper) 47 | // sells 48 | testPushPopRandom(t, 1, 1, 1, msg.SELL, minPopper) 49 | testPushPopRandom(t, 100, 1, 1, msg.SELL, minPopper) 50 | testPushPopRandom(t, 100, 10, 20, msg.SELL, minPopper) 51 | testPushPopRandom(t, 100, 100, 10000, msg.SELL, minPopper) 52 | testPushPopRandom(t, 1000, 100, 10000, msg.SELL, minPopper) 53 | } 54 | 55 | func TestAddRemoveSimple(t *testing.T) { 56 | // Buys 57 | testAddRemoveSimple(t, 1, 1, 1, msg.BUY) 58 | testAddRemoveSimple(t, 100, 1, 1, msg.BUY) 59 | testAddRemoveSimple(t, 100, 10, 20, msg.BUY) 60 | testAddRemoveSimple(t, 100, 100, 10000, msg.BUY) 61 | testAddRemoveSimple(t, 1000, 100, 10000, msg.BUY) 62 | // Sells 63 | testAddRemoveSimple(t, 1, 1, 1, msg.SELL) 64 | testAddRemoveSimple(t, 100, 1, 1, msg.SELL) 65 | testAddRemoveSimple(t, 100, 10, 20, msg.SELL) 66 | testAddRemoveSimple(t, 100, 100, 10000, msg.SELL) 67 | testAddRemoveSimple(t, 1000, 100, 10000, msg.SELL) 68 | } 69 | 70 | func TestAddRemoveRandom(t *testing.T) { 71 | // Buys 72 | testAddRemoveRandom(t, 1, 1, 1, msg.BUY) 73 | testAddRemoveRandom(t, 100, 1, 1, msg.BUY) 74 | testAddRemoveRandom(t, 100, 10, 20, msg.BUY) 75 | testAddRemoveRandom(t, 100, 100, 10000, msg.BUY) 76 | testAddRemoveRandom(t, 1000, 100, 10000, msg.BUY) 77 | // Sells 78 | testAddRemoveRandom(t, 1, 1, 1, msg.SELL) 79 | testAddRemoveRandom(t, 100, 1, 1, msg.SELL) 80 | testAddRemoveRandom(t, 100, 10, 20, msg.SELL) 81 | testAddRemoveRandom(t, 100, 100, 10000, msg.SELL) 82 | testAddRemoveRandom(t, 1000, 100, 10000, msg.SELL) 83 | } 84 | 85 | func testPushAscDesc(t *testing.T, pushCount int, kind msg.MsgKind) { 86 | priceTree 
:= &rbtree{} 87 | guidTree := &rbtree{} 88 | validate(t, priceTree, guidTree) 89 | for i := 0; i < pushCount; i++ { 90 | o := &OrderNode{} 91 | o.CopyFrom(msgMkr.MkPricedOrder(uint64(i), kind)) 92 | priceTree.push(&o.priceNode) 93 | guidTree.push(&o.guidNode) 94 | validate(t, priceTree, guidTree) 95 | } 96 | for i := pushCount - 1; i >= 0; i-- { 97 | o := &OrderNode{} 98 | o.CopyFrom(msgMkr.MkPricedOrder(uint64(i), kind)) 99 | priceTree.push(&o.priceNode) 100 | guidTree.push(&o.guidNode) 101 | validate(t, priceTree, guidTree) 102 | } 103 | } 104 | 105 | func testPushSimple(t *testing.T, pushCount int, lowPrice, highPrice uint64, kind msg.MsgKind) { 106 | priceTree := &rbtree{} 107 | guidTree := &rbtree{} 108 | validate(t, priceTree, guidTree) 109 | for i := 0; i < pushCount; i++ { 110 | o := &OrderNode{} 111 | o.CopyFrom(msgMkr.MkPricedOrder(msgMkr.Between(lowPrice, highPrice), kind)) 112 | priceTree.push(&o.priceNode) 113 | guidTree.push(&o.guidNode) 114 | validate(t, priceTree, guidTree) 115 | } 116 | } 117 | 118 | func testPushPopSimple(t *testing.T, pushCount int, lowPrice, highPrice uint64, kind msg.MsgKind, popper popperFun) { 119 | priceTree := &rbtree{} 120 | guidTree := &rbtree{} 121 | validate(t, priceTree, guidTree) 122 | q := mkPrioq(lowPrice, highPrice) 123 | for i := 0; i < pushCount; i++ { 124 | o := &OrderNode{} 125 | o.CopyFrom(msgMkr.MkPricedOrder(msgMkr.Between(lowPrice, highPrice), kind)) 126 | priceTree.push(&o.priceNode) 127 | guidTree.push(&o.guidNode) 128 | validate(t, priceTree, guidTree) 129 | q.push(o) 130 | } 131 | for i := 0; i < pushCount; i++ { 132 | popCheck(t, priceTree, guidTree, q, popper) 133 | } 134 | } 135 | 136 | func testPushPopRandom(t *testing.T, pushCount int, lowPrice, highPrice uint64, kind msg.MsgKind, popper popperFun) { 137 | priceTree := &rbtree{} 138 | guidTree := &rbtree{} 139 | validate(t, priceTree, guidTree) 140 | q := mkPrioq(lowPrice, highPrice) 141 | r := rand.New(rand.NewSource(1)) 142 | for i := 0; i < 
pushCount; { 143 | n := r.Int() 144 | if n%2 == 0 || priceTree.peekMin() == nil { 145 | o := &OrderNode{} 146 | o.CopyFrom(msgMkr.MkPricedOrder(msgMkr.Between(lowPrice, highPrice), kind)) 147 | priceTree.push(&o.priceNode) 148 | guidTree.push(&o.guidNode) 149 | validate(t, priceTree, guidTree) 150 | q.push(o) 151 | i++ 152 | } else { 153 | popCheck(t, priceTree, guidTree, q, popper) 154 | } 155 | } 156 | for priceTree.peekMin() != nil { 157 | po := priceTree.popMax().getOrderNode() 158 | fo := q.popMax() 159 | if fo != po { 160 | t.Errorf("Mismatched Push/Pop pair") 161 | return 162 | } 163 | ensureFreed(t, po) 164 | validate(t, priceTree, guidTree) 165 | } 166 | } 167 | 168 | func testAddRemoveSimple(t *testing.T, pushCount int, lowPrice, highPrice uint64, kind msg.MsgKind) { 169 | priceTree := &rbtree{} 170 | guidTree := &rbtree{} 171 | validate(t, priceTree, guidTree) 172 | orderMap := make(map[uint64]*OrderNode) 173 | for i := 0; i < pushCount; i++ { 174 | o := &OrderNode{} 175 | o.CopyFrom(msgMkr.MkPricedOrder(msgMkr.Between(lowPrice, highPrice), kind)) 176 | priceTree.push(&o.priceNode) 177 | guidTree.push(&o.guidNode) 178 | validate(t, priceTree, guidTree) 179 | orderMap[o.Guid()] = o 180 | } 181 | drainTree(t, priceTree, guidTree, orderMap) 182 | } 183 | 184 | func testAddRemoveRandom(t *testing.T, pushCount int, lowPrice, highPrice uint64, kind msg.MsgKind) { 185 | priceTree := &rbtree{} 186 | guidTree := &rbtree{} 187 | validate(t, priceTree, guidTree) 188 | orderMap := make(map[uint64]*OrderNode) 189 | r := rand.New(rand.NewSource(1)) 190 | for i := 0; i < pushCount; { 191 | n := r.Int() 192 | if n%2 == 0 || guidTree.peekMin() == nil { 193 | o := &OrderNode{} 194 | o.CopyFrom(msgMkr.MkPricedOrder(msgMkr.Between(lowPrice, highPrice), kind)) 195 | priceTree.push(&o.priceNode) 196 | guidTree.push(&o.guidNode) 197 | validate(t, priceTree, guidTree) 198 | orderMap[o.Guid()] = o 199 | i++ 200 | } else { 201 | for g, o := range orderMap { 202 | po := 
guidTree.cancel(g).getOrderNode() 203 | delete(orderMap, g) 204 | if po != o { 205 | t.Errorf("Bad pop") 206 | } 207 | ensureFreed(t, po) 208 | validate(t, priceTree, guidTree) 209 | break 210 | } 211 | } 212 | } 213 | drainTree(t, priceTree, guidTree, orderMap) 214 | } 215 | 216 | func drainTree(t *testing.T, priceTree, guidTree *rbtree, orderMap map[uint64]*OrderNode) { 217 | for g := range orderMap { 218 | o := orderMap[g] 219 | po := guidTree.cancel(o.Guid()).getOrderNode() 220 | if po != o { 221 | t.Errorf("Bad pop") 222 | } 223 | ensureFreed(t, po) 224 | validate(t, priceTree, guidTree) 225 | } 226 | } 227 | 228 | func ensureFreed(t *testing.T, o *OrderNode) { 229 | if !o.priceNode.isFree() { 230 | t.Errorf("Price Node was not freed") 231 | } 232 | if !o.guidNode.isFree() { 233 | t.Errorf("Guid Node was not freed") 234 | } 235 | } 236 | 237 | // Quick check to ensure the rbtree's internal structure is valid 238 | func validate(t *testing.T, priceTree, guidTree *rbtree) { 239 | if err := validateRBT(priceTree); err != nil { 240 | t.Errorf("%s", err.Error()) 241 | } 242 | if err := validateRBT(guidTree); err != nil { 243 | t.Errorf("%s", err.Error()) 244 | } 245 | checkStructure(t, priceTree.root) 246 | checkStructure(t, guidTree.root) 247 | } 248 | 249 | func checkStructure(t *testing.T, n *node) { 250 | if n == nil { 251 | return 252 | } 253 | checkQueue(t, n) 254 | if *n.pp != n { 255 | t.Errorf("Parent pointer does not point to child node") 256 | } 257 | if n.left != nil { 258 | if n.val <= n.left.val { 259 | t.Errorf("Left value is greater than or equal to node value. Left value: %d Node value %d", n.left.val, n.val) 260 | } 261 | checkStructure(t, n.left) 262 | } 263 | if n.right != nil { 264 | if n.val >= n.right.val { 265 | t.Errorf("Right value is less than or equal to node value. 
Right value: %d Node value %d", n.right.val, n.val) 266 | } 267 | checkStructure(t, n.right) 268 | } 269 | } 270 | 271 | func checkQueue(t *testing.T, n *node) { 272 | curr := n.next 273 | prev := n 274 | for curr != n { 275 | if curr.prev != prev { 276 | t.Errorf("Bad queue next/prev pair") 277 | } 278 | if curr.pp != nil { 279 | t.Errorf("Internal queue node with non-nil parent pointer") 280 | } 281 | if curr.left != nil { 282 | t.Errorf("Internal queue node has non-nil left child") 283 | } 284 | if curr.right != nil { 285 | t.Errorf("Internal queue node has non-nil right child") 286 | } 287 | if curr.order == nil { 288 | t.Errorf("Internal queue node has nil OrderNode") 289 | } 290 | prev = curr 291 | curr = curr.next 292 | } 293 | } 294 | 295 | // Function to pop and peek and check that everything is in order 296 | func popCheck(t *testing.T, priceTree, guidTree *rbtree, q *pqueue, popper popperFun) { 297 | peek, pop, check := popper(t, priceTree, guidTree, q) 298 | if pop != check { 299 | t.Errorf("Mismatched push/pop pair") 300 | return 301 | } 302 | if pop != peek { 303 | t.Errorf("Mismatched peek/pop pair") 304 | return 305 | } 306 | validate(t, priceTree, guidTree) 307 | } 308 | 309 | // Helper functions for popping either the max or the min from our queues 310 | func maxPopper(t *testing.T, priceTree, guidTree *rbtree, q *pqueue) (peek, pop, check *OrderNode) { 311 | peek = priceTree.peekMax().getOrderNode() 312 | if !guidTree.Has(peek.Guid()) { 313 | t.Errorf("Guid rbtree does not contain peeked order") 314 | } 315 | pop = priceTree.popMax().getOrderNode() 316 | if guidTree.Has(peek.Price()) { 317 | t.Errorf("Guid rbtree still contains popped order") 318 | return 319 | } 320 | check = q.popMax() 321 | ensureFreed(t, pop) 322 | return 323 | } 324 | 325 | func minPopper(t *testing.T, priceTree, guidTree *rbtree, q *pqueue) (peek, pop, check *OrderNode) { 326 | peek = priceTree.peekMin().getOrderNode() 327 | if !guidTree.Has(peek.Guid()) { 328 | 
t.Errorf("Guid rbtree does not contain peeked order") 329 | } 330 | pop = priceTree.popMin().getOrderNode() 331 | check = q.popMin() 332 | ensureFreed(t, pop) 333 | return 334 | } 335 | -------------------------------------------------------------------------------- /matcher/testsuite.go: -------------------------------------------------------------------------------- 1 | package matcher 2 | 3 | import ( 4 | . "github.com/fmstephe/matching_engine/msg" 5 | "testing" 6 | ) 7 | 8 | var suiteMaker = NewMessageMaker(100) 9 | 10 | type MatchTester interface { 11 | Send(*testing.T, *Message) 12 | Expect(*testing.T, *Message) 13 | Cleanup(*testing.T) 14 | } 15 | 16 | type MatchTesterMaker interface { 17 | Make() MatchTester 18 | } 19 | 20 | func RunTestSuite(t *testing.T, mkr MatchTesterMaker) { 21 | testSellBuyMatch(t, mkr) 22 | testBuySellMatch(t, mkr) 23 | testBuyDoubleSellMatch(t, mkr) 24 | testSellDoubleBuyMatch(t, mkr) 25 | testMidPrice(t, mkr) 26 | testMidPriceBigSell(t, mkr) 27 | testMidPriceBigBuy(t, mkr) 28 | testTradeSeparateStocksI(t, mkr) 29 | testTradeSeparateStocksII(t, mkr) 30 | testSellCancelBuyNoMatch(t, mkr) 31 | testBuyCancelSellNoMatch(t, mkr) 32 | testBadCancelNotCancelled(t, mkr) 33 | testThreeBuysMatchedToOneSell(t, mkr) 34 | } 35 | 36 | func testSellBuyMatch(t *testing.T, mkr MatchTesterMaker) { 37 | mt := mkr.Make() 38 | defer mt.Cleanup(t) 39 | addLowBuys(t, mt, 5, 1) 40 | addHighSells(t, mt, 10, 1) 41 | // Add Sell 42 | s := &Message{Kind: SELL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 43 | mt.Send(t, s) 44 | // Add Buy 45 | b := &Message{Kind: BUY, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 46 | mt.Send(t, b) 47 | // Full match 48 | es := &Message{Kind: FULL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 49 | mt.Expect(t, es) 50 | eb := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 51 | mt.Expect(t, eb) 52 | } 53 | 54 | func testBuySellMatch(t *testing.T, mkr 
MatchTesterMaker) { 55 | mt := mkr.Make() 56 | defer mt.Cleanup(t) 57 | addLowBuys(t, mt, 5, 1) 58 | addHighSells(t, mt, 10, 1) 59 | // Add Buy 60 | b := &Message{Kind: BUY, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 61 | mt.Send(t, b) 62 | // Add Sell 63 | s := &Message{Kind: SELL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 64 | mt.Send(t, s) 65 | // Full match 66 | eb := &Message{Kind: FULL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 67 | mt.Expect(t, eb) 68 | es := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 69 | mt.Expect(t, es) 70 | } 71 | 72 | func testBuyDoubleSellMatch(t *testing.T, mkr MatchTesterMaker) { 73 | mt := mkr.Make() 74 | defer mt.Cleanup(t) 75 | addLowBuys(t, mt, 5, 1) 76 | addHighSells(t, mt, 10, 1) 77 | // Add Buy 78 | b := &Message{Kind: BUY, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 2} 79 | mt.Send(t, b) 80 | // Add Sell1 81 | s1 := &Message{Kind: SELL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 82 | mt.Send(t, s1) 83 | // Full match 84 | eb1 := &Message{Kind: PARTIAL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 85 | mt.Expect(t, eb1) 86 | es1 := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 87 | mt.Expect(t, es1) 88 | // Add Sell2 89 | s2 := &Message{Kind: SELL, TraderId: 2, TradeId: 2, StockId: 1, Price: 7, Amount: 1} 90 | mt.Send(t, s2) 91 | // Full Match II 92 | eb2 := &Message{Kind: FULL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 93 | mt.Expect(t, eb2) 94 | es2 := &Message{Kind: FULL, TraderId: 2, TradeId: 2, StockId: 1, Price: 7, Amount: 1} 95 | mt.Expect(t, es2) 96 | } 97 | 98 | func testSellDoubleBuyMatch(t *testing.T, mkr MatchTesterMaker) { 99 | mt := mkr.Make() 100 | defer mt.Cleanup(t) 101 | addLowBuys(t, mt, 5, 1) 102 | addHighSells(t, mt, 10, 1) 103 | // Add Sell 104 | s := &Message{Kind: SELL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 2} 105 | 
mt.Send(t, s) 106 | // Add Buy1 107 | b1 := &Message{Kind: BUY, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 108 | mt.Send(t, b1) 109 | // Full match on the buy 110 | eb1 := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 111 | mt.Expect(t, eb1) 112 | // Partial match on the sell 113 | es1 := &Message{Kind: PARTIAL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 114 | mt.Expect(t, es1) 115 | // Add Buy2 116 | b2 := &Message{Kind: BUY, TraderId: 2, TradeId: 2, StockId: 1, Price: 7, Amount: 1} 117 | mt.Send(t, b2) 118 | // Full Match II 119 | eb2 := &Message{Kind: FULL, TraderId: 2, TradeId: 2, StockId: 1, Price: 7, Amount: 1} 120 | mt.Expect(t, eb2) 121 | es2 := &Message{Kind: FULL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 122 | mt.Expect(t, es2) 123 | } 124 | 125 | func testMidPrice(t *testing.T, mkr MatchTesterMaker) { 126 | mt := mkr.Make() 127 | defer mt.Cleanup(t) 128 | addLowBuys(t, mt, 5, 1) 129 | addHighSells(t, mt, 10, 1) 130 | // Add Buy 131 | b := &Message{Kind: BUY, TraderId: 1, TradeId: 1, StockId: 1, Price: 9, Amount: 1} 132 | mt.Send(t, b) 133 | // Add Sell 134 | s := &Message{Kind: SELL, TraderId: 2, TradeId: 1, StockId: 1, Price: 6, Amount: 1} 135 | mt.Send(t, s) 136 | // Full match 137 | eb := &Message{Kind: FULL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 138 | mt.Expect(t, eb) 139 | es := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 140 | mt.Expect(t, es) 141 | } 142 | 143 | func testMidPriceBigSell(t *testing.T, mkr MatchTesterMaker) { 144 | mt := mkr.Make() 145 | defer mt.Cleanup(t) 146 | addLowBuys(t, mt, 5, 1) 147 | addHighSells(t, mt, 10, 1) 148 | // Add Buy 149 | b := &Message{Kind: BUY, TraderId: 1, TradeId: 1, StockId: 1, Price: 9, Amount: 1} 150 | mt.Send(t, b) 151 | // Add Sell 152 | s := &Message{Kind: SELL, TraderId: 2, TradeId: 1, StockId: 1, Price: 6, Amount: 10} 153 | mt.Send(t, s) 154 | // Full match 155 | eb := 
&Message{Kind: FULL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 156 | mt.Expect(t, eb) 157 | es := &Message{Kind: PARTIAL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 158 | mt.Expect(t, es) 159 | } 160 | 161 | func testMidPriceBigBuy(t *testing.T, mkr MatchTesterMaker) { 162 | mt := mkr.Make() 163 | defer mt.Cleanup(t) 164 | addLowBuys(t, mt, 5, 1) 165 | addHighSells(t, mt, 10, 1) 166 | // Add Buy 167 | b := &Message{Kind: BUY, TraderId: 1, TradeId: 1, StockId: 1, Price: 9, Amount: 10} 168 | mt.Send(t, b) 169 | // Add Sell 170 | s := &Message{Kind: SELL, TraderId: 2, TradeId: 1, StockId: 1, Price: 6, Amount: 1} 171 | mt.Send(t, s) 172 | // Full match 173 | eb := &Message{Kind: PARTIAL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 174 | mt.Expect(t, eb) 175 | es := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 176 | mt.Expect(t, es) 177 | } 178 | 179 | func testTradeSeparateStocksI(t *testing.T, mkr MatchTesterMaker) { 180 | mt := mkr.Make() 181 | defer mt.Cleanup(t) 182 | addLowBuys(t, mt, 5, 1) 183 | addHighSells(t, mt, 10, 1) 184 | // Add Sell stock 1 185 | s1 := &Message{Kind: SELL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 186 | mt.Send(t, s1) 187 | // Add Buy stock 2 188 | b2 := &Message{Kind: BUY, TraderId: 1, TradeId: 1, StockId: 2, Price: 7, Amount: 1} 189 | mt.Send(t, b2) 190 | // Add Sell stock 2 191 | s2 := &Message{Kind: SELL, TraderId: 2, TradeId: 2, StockId: 2, Price: 7, Amount: 1} 192 | mt.Send(t, s2) 193 | // Full match stock 2 194 | es2 := &Message{Kind: FULL, TraderId: 1, TradeId: 1, StockId: 2, Price: 7, Amount: 1} 195 | mt.Expect(t, es2) 196 | eb2 := &Message{Kind: FULL, TraderId: 2, TradeId: 2, StockId: 2, Price: 7, Amount: 1} 197 | mt.Expect(t, eb2) 198 | // Add Buy stock 1 199 | b1 := &Message{Kind: BUY, TraderId: 1, TradeId: 2, StockId: 1, Price: 7, Amount: 1} 200 | mt.Send(t, b1) 201 | // Full match stock 1 202 | eb1 := &Message{Kind: FULL, TraderId: 
1, TradeId: 2, StockId: 1, Price: 7, Amount: 1} 203 | mt.Expect(t, eb1) 204 | es1 := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 205 | mt.Expect(t, es1) 206 | } 207 | 208 | func testTradeSeparateStocksII(t *testing.T, mkr MatchTesterMaker) { 209 | mt := mkr.Make() 210 | defer mt.Cleanup(t) 211 | addLowBuys(t, mt, 5, 1) 212 | addHighSells(t, mt, 10, 1) 213 | // Add Sell stock 1 214 | s1 := &Message{Kind: SELL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 215 | mt.Send(t, s1) 216 | // Add Buy stock 2 217 | b1 := &Message{Kind: BUY, TraderId: 2, TradeId: 1, StockId: 2, Price: 7, Amount: 1} 218 | mt.Send(t, b1) 219 | // Add buy stock 1 220 | s2 := &Message{Kind: BUY, TraderId: 3, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 221 | mt.Send(t, s2) 222 | // Expect match on stock 1 223 | eb1 := &Message{Kind: FULL, TraderId: 3, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 224 | mt.Expect(t, eb1) 225 | es1 := &Message{Kind: FULL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 226 | mt.Expect(t, es1) 227 | // Add sell stock 2 228 | b2 := &Message{Kind: SELL, TraderId: 4, TradeId: 1, StockId: 2, Price: 7, Amount: 1} 229 | mt.Send(t, b2) 230 | // Expect match on stock 2 231 | eb2 := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 2, Price: 7, Amount: 1} 232 | mt.Expect(t, eb2) 233 | es2 := &Message{Kind: FULL, TraderId: 4, TradeId: 1, StockId: 2, Price: 7, Amount: 1} 234 | mt.Expect(t, es2) 235 | } 236 | 237 | func testSellCancelBuyNoMatch(t *testing.T, mkr MatchTesterMaker) { 238 | mt := mkr.Make() 239 | defer mt.Cleanup(t) 240 | addLowBuys(t, mt, 5, 1) 241 | addHighSells(t, mt, 10, 1) 242 | // Add Sell 243 | s := &Message{Kind: SELL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 244 | mt.Send(t, s) 245 | // Cancel Sell 246 | cs := &Message{Kind: CANCEL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 247 | mt.Send(t, cs) 248 | // Expect Cancelled 249 | ec := &Message{Kind: CANCELLED, 
TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 250 | mt.Expect(t, ec) 251 | // Add Buy 252 | b := &Message{Kind: BUY, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 253 | mt.Send(t, b) 254 | // Add Sell 255 | s2 := &Message{Kind: SELL, TraderId: 3, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 256 | mt.Send(t, s2) 257 | // Expect match for traders 1 and 3 258 | eb := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 259 | mt.Expect(t, eb) 260 | es := &Message{Kind: FULL, TraderId: 3, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 261 | mt.Expect(t, es) 262 | } 263 | 264 | func testBuyCancelSellNoMatch(t *testing.T, mkr MatchTesterMaker) { 265 | mt := mkr.Make() 266 | defer mt.Cleanup(t) 267 | addLowBuys(t, mt, 5, 1) 268 | addHighSells(t, mt, 10, 1) 269 | // Add Buy 270 | b := &Message{Kind: BUY, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 271 | mt.Send(t, b) 272 | // Cancel Buy 273 | cb := &Message{Kind: CANCEL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 274 | mt.Send(t, cb) 275 | // Expect Cancelled 276 | ec := &Message{Kind: CANCELLED, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 277 | mt.Expect(t, ec) 278 | // Add Sell 279 | s := &Message{Kind: SELL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 280 | mt.Send(t, s) 281 | // Add Buy 282 | b2 := &Message{Kind: BUY, TraderId: 3, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 283 | mt.Send(t, b2) 284 | // Expect match for traders 1 and 3 285 | eb := &Message{Kind: FULL, TraderId: 3, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 286 | mt.Expect(t, eb) 287 | es := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 288 | mt.Expect(t, es) 289 | } 290 | 291 | func testBadCancelNotCancelled(t *testing.T, mkr MatchTesterMaker) { 292 | mt := mkr.Make() 293 | defer mt.Cleanup(t) 294 | addLowBuys(t, mt, 5, 1) 295 | addHighSells(t, mt, 10, 1) 296 | // Cancel Buy that doesn't exist 297 | cb := 
&Message{Kind: CANCEL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 298 | mt.Send(t, cb) 299 | // Expect Not Cancelled 300 | ec := &Message{Kind: NOT_CANCELLED, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 301 | mt.Expect(t, ec) 302 | } 303 | 304 | // Defect found where the second PARTIAL sell match is getting filtered because it is identical to the first 305 | func testThreeBuysMatchedToOneSell(t *testing.T, mkr MatchTesterMaker) { 306 | mt := mkr.Make() 307 | defer mt.Cleanup(t) 308 | addLowBuys(t, mt, 5, 1) 309 | addHighSells(t, mt, 10, 1) 310 | // Three buys 311 | b1 := &Message{Kind: BUY, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 312 | mt.Send(t, b1) 313 | b2 := &Message{Kind: BUY, TraderId: 1, TradeId: 2, StockId: 1, Price: 7, Amount: 1} 314 | mt.Send(t, b2) 315 | b3 := &Message{Kind: BUY, TraderId: 1, TradeId: 3, StockId: 1, Price: 7, Amount: 1} 316 | mt.Send(t, b3) 317 | // One big sell 318 | s := &Message{Kind: SELL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 3} 319 | mt.Send(t, s) 320 | // Expect full matches on all three buys 321 | eb1 := &Message{Kind: FULL, TraderId: 1, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 322 | mt.Expect(t, eb1) 323 | es1 := &Message{Kind: PARTIAL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 324 | mt.Expect(t, es1) 325 | eb2 := &Message{Kind: FULL, TraderId: 1, TradeId: 2, StockId: 1, Price: 7, Amount: 1} 326 | mt.Expect(t, eb2) 327 | es2 := &Message{Kind: PARTIAL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 328 | mt.Expect(t, es2) 329 | eb3 := &Message{Kind: FULL, TraderId: 1, TradeId: 3, StockId: 1, Price: 7, Amount: 1} 330 | mt.Expect(t, eb3) 331 | es3 := &Message{Kind: FULL, TraderId: 2, TradeId: 1, StockId: 1, Price: 7, Amount: 1} 332 | mt.Expect(t, es3) 333 | } 334 | 335 | func addLowBuys(t *testing.T, mt MatchTester, highestPrice uint64, stockId uint64) { 336 | buys := suiteMaker.MkBuys(suiteMaker.ValRangeFlat(10, 1, highestPrice), 
stockId) 337 | for i := range buys { 338 | mt.Send(t, &buys[i]) 339 | } 340 | } 341 | 342 | func addHighSells(t *testing.T, mt MatchTester, lowestPrice uint64, stockId uint64) { 343 | sells := suiteMaker.MkSells(suiteMaker.ValRangeFlat(10, lowestPrice, lowestPrice+10000), stockId) 344 | for i := range sells { 345 | mt.Send(t, &sells[i]) 346 | } 347 | } 348 | -------------------------------------------------------------------------------- /bin/svr/client/ttrader_unit_test.go.disabled: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "github.com/fmstephe/matching_engine/msg" 5 | "runtime" 6 | "testing" 7 | ) 8 | 9 | func TestMessageProcessBuyCancelCancelled(t *testing.T) { 10 | traderId := uint32(1) 11 | bm := newBalanceManager(traderId, 100, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 12 | expectBalance(t, bm, 100, 100) // This test expects that 100 is 100 13 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 14 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 15 | // Submit buys 16 | m1 := &msg.Message{Kind: msg.BUY, TraderId: traderId, TradeId: 0, StockId: 1, Price: 5, Amount: 5} 17 | canProcess(t, bm, m1) 18 | expectBalance(t, bm, 75, 100) 19 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 20 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 21 | validate(t, bm) 22 | // Cancel 23 | canProcess(t, bm, &msg.Message{Kind: msg.CANCEL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 5, Amount: 5}) 24 | expectBalance(t, bm, 75, 100) 25 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 26 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 27 | validate(t, bm) 28 | // Confirm CANCELLED 29 | canProcess(t, bm, &msg.Message{Kind: msg.CANCELLED, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 5, Amount: 5}) 30 | expectBalance(t, bm, 100, 100) 31 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 32 | expectInMap(t, bm.toSell, 
map[uint64]uint64{}) 33 | validate(t, bm) 34 | } 35 | 36 | func TestMessageProcessBuyFullSimple(t *testing.T) { 37 | traderId := uint32(1) 38 | bm := newBalanceManager(traderId, 100, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 39 | expectBalance(t, bm, 100, 100) // This test expects that 100 is 100 40 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 41 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 42 | // Submit buy 43 | m1 := &msg.Message{Kind: msg.BUY, TraderId: traderId, TradeId: 0, StockId: 1, Price: 5, Amount: 5} 44 | canProcess(t, bm, m1) 45 | expectBalance(t, bm, 75, 100) 46 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 47 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 48 | validate(t, bm) 49 | // Full Match 50 | canProcess(t, bm, &msg.Message{Kind: msg.FULL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 5, Amount: 5}) 51 | expectBalance(t, bm, 75, 75) 52 | expectInMap(t, bm.held, map[uint64]uint64{1: 15, 2: 10, 3: 10}) 53 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 54 | validate(t, bm) 55 | } 56 | 57 | func TestMessageProcessBuyFullDiffSimple(t *testing.T) { 58 | traderId := uint32(1) 59 | bm := newBalanceManager(traderId, 100, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 60 | expectBalance(t, bm, 100, 100) // This test expects that 100 is 100 61 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 62 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 63 | // Submit buy 64 | m1 := &msg.Message{Kind: msg.BUY, TraderId: traderId, TradeId: 0, StockId: 1, Price: 5, Amount: 5} 65 | canProcess(t, bm, m1) 66 | expectBalance(t, bm, 75, 100) 67 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 68 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 69 | validate(t, bm) 70 | // Full Match - at lower than bid price 71 | canProcess(t, bm, &msg.Message{Kind: msg.FULL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 4, Amount: 5}) 72 | expectBalance(t, bm, 80, 80) 73 | expectInMap(t, 
bm.held, map[uint64]uint64{1: 15, 2: 10, 3: 10}) 74 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 75 | validate(t, bm) 76 | } 77 | 78 | func TestMessageProcessBuyPartialSimple(t *testing.T) { 79 | traderId := uint32(1) 80 | bm := newBalanceManager(traderId, 100, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 81 | expectBalance(t, bm, 100, 100) // This test expects that 100 is 100 82 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 83 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 84 | // Submit some buys 85 | m1 := &msg.Message{Kind: msg.BUY, TraderId: traderId, TradeId: 0, StockId: 1, Price: 5, Amount: 5} 86 | canProcess(t, bm, m1) 87 | expectBalance(t, bm, 75, 100) 88 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 89 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 90 | validate(t, bm) 91 | // Partial Match 92 | canProcess(t, bm, &msg.Message{Kind: msg.PARTIAL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 5, Amount: 2}) 93 | expectBalance(t, bm, 75, 90) 94 | expectInMap(t, bm.held, map[uint64]uint64{1: 12, 2: 10, 3: 10}) 95 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 96 | validate(t, bm) 97 | canProcess(t, bm, &msg.Message{Kind: msg.PARTIAL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 5, Amount: 3}) 98 | expectBalance(t, bm, 75, 75) 99 | expectInMap(t, bm.held, map[uint64]uint64{1: 15, 2: 10, 3: 10}) 100 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 101 | validate(t, bm) 102 | } 103 | 104 | func TestMessageProcessBuyPartialDiffSimple(t *testing.T) { 105 | traderId := uint32(1) 106 | bm := newBalanceManager(traderId, 100, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 107 | expectBalance(t, bm, 100, 100) // This test expects that 100 is 100 108 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 109 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 110 | // Submit some buys 111 | m1 := &msg.Message{Kind: msg.BUY, TraderId: traderId, TradeId: 0, StockId: 1, Price: 5, Amount: 5} 112 | 
canProcess(t, bm, m1) 113 | expectBalance(t, bm, 75, 100) 114 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 115 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 116 | validate(t, bm) 117 | // Partial Matches at lower than bid price 118 | canProcess(t, bm, &msg.Message{Kind: msg.PARTIAL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 4, Amount: 2}) 119 | expectBalance(t, bm, 77, 92) 120 | expectInMap(t, bm.held, map[uint64]uint64{1: 12, 2: 10, 3: 10}) 121 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 122 | validate(t, bm) 123 | canProcess(t, bm, &msg.Message{Kind: msg.PARTIAL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 3, Amount: 3}) 124 | expectBalance(t, bm, 83, 83) 125 | expectInMap(t, bm.held, map[uint64]uint64{1: 15, 2: 10, 3: 10}) 126 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 127 | validate(t, bm) 128 | } 129 | 130 | func TestMessageProcessSellCancelCancelled(t *testing.T) { 131 | traderId := uint32(1) 132 | bm := newBalanceManager(traderId, 100, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 133 | expectBalance(t, bm, 100, 100) // This test expects that 100 is 100 134 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 135 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 136 | // Submit sell 137 | m1 := &msg.Message{Kind: msg.SELL, TraderId: traderId, TradeId: 0, StockId: 1, Price: 5, Amount: 5} 138 | canProcess(t, bm, m1) 139 | expectBalance(t, bm, 100, 100) 140 | expectInMap(t, bm.held, map[uint64]uint64{1: 5, 2: 10, 3: 10}) 141 | expectInMap(t, bm.toSell, map[uint64]uint64{1: 5}) 142 | validate(t, bm) 143 | // Cancel sell 144 | canProcess(t, bm, &msg.Message{Kind: msg.CANCEL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 5, Amount: 5}) 145 | expectBalance(t, bm, 100, 100) 146 | expectInMap(t, bm.held, map[uint64]uint64{1: 5, 2: 10, 3: 10}) 147 | expectInMap(t, bm.toSell, map[uint64]uint64{1: 5}) 148 | validate(t, bm) 149 | // Confirm CANCELLED 150 | canProcess(t, bm, 
&msg.Message{Kind: msg.CANCELLED, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 5, Amount: 5}) 151 | expectBalance(t, bm, 100, 100) 152 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 153 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 154 | validate(t, bm) 155 | } 156 | 157 | func TestMessageProcessSellFullSimple(t *testing.T) { 158 | traderId := uint32(1) 159 | bm := newBalanceManager(traderId, 100, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 160 | expectBalance(t, bm, 100, 100) // This test expects that 100 is 100 161 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 162 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 163 | // Submit sell 164 | m1 := &msg.Message{Kind: msg.SELL, TraderId: traderId, TradeId: 0, StockId: 1, Price: 5, Amount: 5} 165 | canProcess(t, bm, m1) 166 | expectBalance(t, bm, 100, 100) 167 | expectInMap(t, bm.held, map[uint64]uint64{1: 5, 2: 10, 3: 10}) 168 | expectInMap(t, bm.toSell, map[uint64]uint64{1: 5}) 169 | validate(t, bm) 170 | // Match sell FULL 171 | canProcess(t, bm, &msg.Message{Kind: msg.FULL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 5, Amount: 5}) 172 | expectBalance(t, bm, 125, 125) 173 | expectInMap(t, bm.held, map[uint64]uint64{1: 5, 2: 10, 3: 10}) 174 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 175 | validate(t, bm) 176 | } 177 | 178 | func TestMessageProcessSellFullDiff(t *testing.T) { 179 | traderId := uint32(1) 180 | bm := newBalanceManager(traderId, 100, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 181 | expectBalance(t, bm, 100, 100) // This test expects that 100 is 100 182 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 183 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 184 | // Submit sell 185 | m1 := &msg.Message{Kind: msg.SELL, TraderId: traderId, TradeId: 0, StockId: 1, Price: 5, Amount: 5} 186 | canProcess(t, bm, m1) 187 | expectBalance(t, bm, 100, 100) 188 | expectInMap(t, bm.held, map[uint64]uint64{1: 5, 2: 10, 3: 10}) 189 | 
expectInMap(t, bm.toSell, map[uint64]uint64{1: 5}) 190 | validate(t, bm) 191 | // Match sell FUll 192 | canProcess(t, bm, &msg.Message{Kind: msg.FULL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 7, Amount: 5}) 193 | expectBalance(t, bm, 135, 135) 194 | expectInMap(t, bm.held, map[uint64]uint64{1: 5, 2: 10, 3: 10}) 195 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 196 | validate(t, bm) 197 | } 198 | 199 | func TestMessageProcessSellPartialSimple(t *testing.T) { 200 | traderId := uint32(1) 201 | bm := newBalanceManager(traderId, 100, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 202 | expectBalance(t, bm, 100, 100) // This test expects that 100 is 100 203 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 204 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 205 | // Submit sell 206 | m1 := &msg.Message{Kind: msg.SELL, TraderId: traderId, TradeId: 0, StockId: 1, Price: 5, Amount: 5} 207 | canProcess(t, bm, m1) 208 | expectBalance(t, bm, 100, 100) 209 | expectInMap(t, bm.held, map[uint64]uint64{1: 5, 2: 10, 3: 10}) 210 | expectInMap(t, bm.toSell, map[uint64]uint64{1: 5}) 211 | validate(t, bm) 212 | // Match sell PARTIAL 213 | canProcess(t, bm, &msg.Message{Kind: msg.PARTIAL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 5, Amount: 3}) 214 | expectBalance(t, bm, 115, 115) 215 | expectInMap(t, bm.held, map[uint64]uint64{1: 5, 2: 10, 3: 10}) 216 | expectInMap(t, bm.toSell, map[uint64]uint64{1: 2}) 217 | validate(t, bm) 218 | } 219 | 220 | func TestMessageProcessSellPartialDiff(t *testing.T) { 221 | traderId := uint32(1) 222 | bm := newBalanceManager(traderId, 100, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 223 | expectBalance(t, bm, 100, 100) // This test expects that 100 is 100 224 | expectInMap(t, bm.held, map[uint64]uint64{1: 10, 2: 10, 3: 10}) 225 | expectInMap(t, bm.toSell, map[uint64]uint64{}) 226 | // Submit sell 227 | m1 := &msg.Message{Kind: msg.SELL, TraderId: traderId, TradeId: 0, StockId: 1, Price: 5, Amount: 5} 228 | 
canProcess(t, bm, m1) 229 | expectBalance(t, bm, 100, 100) 230 | expectInMap(t, bm.held, map[uint64]uint64{1: 5, 2: 10, 3: 10}) 231 | expectInMap(t, bm.toSell, map[uint64]uint64{1: 5}) 232 | validate(t, bm) 233 | // Match sell PARTIAL 234 | canProcess(t, bm, &msg.Message{Kind: msg.PARTIAL, TraderId: traderId, TradeId: m1.TradeId, StockId: 1, Price: 7, Amount: 3}) 235 | expectBalance(t, bm, 121, 121) 236 | expectInMap(t, bm.held, map[uint64]uint64{1: 5, 2: 10, 3: 10}) 237 | expectInMap(t, bm.toSell, map[uint64]uint64{1: 2}) 238 | validate(t, bm) 239 | } 240 | 241 | func TestCanBuyCannotBuy(t *testing.T) { 242 | traderId := uint32(1) 243 | balance := uint64(100) 244 | bm := newBalanceManager(traderId, balance, map[uint64]uint64{}) 245 | // If the (stocks * price) <= bm.Balance then we can buy 246 | for amount := uint64(1); amount <= balance; amount++ { 247 | for price := uint64(1); price <= balance/amount; price++ { 248 | m := &msg.Message{Kind: msg.BUY, TraderId: traderId, TradeId: 0, StockId: 1, Price: price, Amount: amount} 249 | canProcess(t, bm, m) 250 | c := &msg.Message{Kind: msg.CANCEL, TraderId: traderId, TradeId: m.TradeId, StockId: 1, Price: price, Amount: amount} 251 | canProcess(t, bm, c) 252 | cd := &msg.Message{Kind: msg.CANCELLED, TraderId: traderId, TradeId: m.TradeId, StockId: 1, Price: price, Amount: amount} 253 | canProcess(t, bm, cd) 254 | } 255 | } 256 | // If the (stocks * price) > bm.Balance then we can't buy 257 | for amount := uint64(1); amount <= balance; amount++ { 258 | initPrice := (balance / amount) + 1 259 | for price := initPrice; price <= initPrice*5; price++ { 260 | m := &msg.Message{Kind: msg.BUY, TraderId: traderId, TradeId: 0, StockId: 1, Price: price, Amount: amount} 261 | cannotProcess(t, bm, m) 262 | } 263 | } 264 | } 265 | 266 | func TestCanSellCannotSellAmount(t *testing.T) { 267 | traderId := uint32(1) 268 | balance := uint64(100) 269 | amount := uint64(100) 270 | bm := newBalanceManager(traderId, balance, 
map[uint64]uint64{1: amount}) 271 | // If the stock amount <= stock held then we can sell 272 | for i := uint64(1); i < amount; i++ { 273 | m := &msg.Message{Kind: msg.SELL, TraderId: traderId, TradeId: 0, StockId: 1, Price: 1, Amount: i} 274 | canProcess(t, bm, m) 275 | c := &msg.Message{Kind: msg.CANCEL, TraderId: traderId, TradeId: m.TradeId, StockId: 1, Price: 1, Amount: i} 276 | canProcess(t, bm, c) 277 | cd := &msg.Message{Kind: msg.CANCELLED, TraderId: traderId, TradeId: m.TradeId, StockId: 1, Price: 1, Amount: i} 278 | canProcess(t, bm, cd) 279 | } 280 | // If the stock amount >= stock held then we can't sell 281 | for i := amount + 1; i < amount*3; i++ { 282 | m := &msg.Message{Kind: msg.SELL, TraderId: traderId, TradeId: 0, StockId: 1, Price: 1, Amount: i} 283 | cannotProcess(t, bm, m) 284 | } 285 | } 286 | 287 | func TestCanSellCannotSellStockId(t *testing.T) { 288 | traderId := uint32(1) 289 | balance := uint64(100) 290 | for stockId := uint64(2); stockId < 100; stockId++ { 291 | bm := newBalanceManager(traderId, balance, map[uint64]uint64{stockId: 1}) 292 | mTooLow := &msg.Message{Kind: msg.SELL, TraderId: traderId, TradeId: 0, StockId: stockId - 1, Price: 1, Amount: 1} 293 | cannotProcess(t, bm, mTooLow) 294 | mTooHigh := &msg.Message{Kind: msg.SELL, TraderId: traderId, TradeId: 0, StockId: stockId + 1, Price: 1, Amount: 1} 295 | cannotProcess(t, bm, mTooHigh) 296 | mJustRight := &msg.Message{Kind: msg.SELL, TraderId: traderId, TradeId: 0, StockId: stockId, Price: 1, Amount: 1} 297 | canProcess(t, bm, mJustRight) 298 | } 299 | } 300 | 301 | func validate(t *testing.T, bm *balanceManager) { 302 | // 1: current - available = sum(outstanding buys) 303 | totalBuys := 0 304 | for _, m := range bm.outstanding { 305 | if m.Kind == msg.BUY { 306 | totalBuys += int(m.Price) * int(m.Amount) 307 | } 308 | } 309 | diff := (bm.current - bm.available) 310 | if totalBuys != int(diff) { 311 | _, fname, lnum, _ := runtime.Caller(1) 312 | t.Errorf("Total buys 
outstanding: %d, current - available: %d\n%s:%d", totalBuys, diff, fname, lnum) 313 | } 314 | // 2: balance to sell = oustanding sells 315 | for stockId, amount := range bm.toSell { 316 | totalSells := 0 317 | for _, m := range bm.outstanding { 318 | if m.Kind == msg.SELL { 319 | totalSells += int(m.Amount) 320 | } 321 | } 322 | if totalSells != int(amount) { 323 | _, fname, lnum, _ := runtime.Caller(1) 324 | t.Errorf("%d to sell: %d, outstanding sells for %d: %d\n%s:%d", stockId, amount, stockId, totalSells, fname, lnum) 325 | } 326 | } 327 | } 328 | 329 | func canProcess(t *testing.T, bm *balanceManager, m *msg.Message) { 330 | expectCanProcess(t, bm, m, true) 331 | } 332 | 333 | func cannotProcess(t *testing.T, bm *balanceManager, m *msg.Message) { 334 | expectCanProcess(t, bm, m, false) 335 | } 336 | 337 | func expectCanProcess(t *testing.T, bm *balanceManager, m *msg.Message, can bool) { 338 | mod := "" 339 | if !can { 340 | mod = "not " 341 | } 342 | if can != bm.process(m) { 343 | _, fname, lnum, _ := runtime.Caller(2) 344 | t.Errorf("Expected to %sbe able to process %v - available: %d, current %d.\n%s:%d", mod, m, bm.available, bm.current, fname, lnum) 345 | } 346 | } 347 | 348 | func expectBalance(t *testing.T, bm *balanceManager, available, current uint64) { 349 | if available != bm.available { 350 | _, fname, lnum, _ := runtime.Caller(1) 351 | t.Errorf("Expected available %d, found %d\n%s:%d", available, bm.available, fname, lnum) 352 | } 353 | if current != bm.current { 354 | _, fname, lnum, _ := runtime.Caller(1) 355 | t.Errorf("Expected current %d, found %d\n%s:%d", current, bm.current, fname, lnum) 356 | } 357 | } 358 | 359 | func expectInMap(t *testing.T, expected, actual map[uint64]uint64) { 360 | for stock, eAmount := range expected { 361 | aAmount := actual[stock] 362 | if aAmount != eAmount { 363 | _, fname, lnum, _ := runtime.Caller(1) 364 | t.Errorf("Expected (stock %d: amount %d) but found (stock %d: amount %d)\n%s:%d", stock, eAmount, stock, 
aAmount, fname, lnum) 365 | } 366 | } 367 | for stock, aAmount := range actual { 368 | _, ok := expected[stock] 369 | if !ok { 370 | _, fname, lnum, _ := runtime.Caller(1) 371 | t.Errorf("(stock %d: amount %d) was not expected\n%s:%d", stock, aAmount, fname, lnum) 372 | } 373 | } 374 | } 375 | --------------------------------------------------------------------------------