diff --git a/exchanges/orderbook/depth.go b/exchanges/orderbook/depth.go index 2aebe116..b4aa6fc8 100644 --- a/exchanges/orderbook/depth.go +++ b/exchanges/orderbook/depth.go @@ -7,6 +7,7 @@ import ( "time" "github.com/gofrs/uuid" + "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/dispatch" "github.com/thrasher-corp/gocryptotrader/exchanges/alert" "github.com/thrasher-corp/gocryptotrader/log" @@ -18,6 +19,8 @@ var ( ErrOrderbookInvalid = errors.New("orderbook data integrity compromised") // ErrInvalidAction defines and error when an action is invalid ErrInvalidAction = errors.New("invalid action") + + errInvalidBookDepth = errors.New("invalid book depth") ) // Outbound restricts outbound usage of depth. NOTE: Type assert to @@ -73,8 +76,8 @@ func (d *Depth) Retrieve() (*Base, error) { return nil, d.validationError } return &Base{ - Bids: d.bids.retrieve(), - Asks: d.asks.retrieve(), + Bids: d.bids.retrieve(0), + Asks: d.asks.retrieve(0), Exchange: d.exchange, Asset: d.asset, Pair: d.pair, @@ -711,3 +714,29 @@ func (d *Depth) GetImbalance() (float64, error) { } return (bidVolume - askVolume) / (bidVolume + askVolume), nil } + +// GetTranches returns the desired tranche for the required depth count. If +// count is 0, it will return the entire orderbook. Count == 1 will retrieve the +// best bid and ask. If the required count exceeds the orderbook depth, it will +// return the entire orderbook. +func (d *Depth) GetTranches(count int) (ask, bid []Item, err error) { + if count < 0 { + return nil, nil, errInvalidBookDepth + } + d.m.Lock() + defer d.m.Unlock() + if d.validationError != nil { + return nil, nil, d.validationError + } + return d.asks.retrieve(count), d.bids.retrieve(count), nil +} + +// GetPair returns the pair associated with the depth +func (d *Depth) GetPair() (currency.Pair, error) { + d.m.Lock() + defer d.m.Unlock() + if d.pair.IsEmpty() { + return currency.Pair{}, currency.ErrCurrencyPairEmpty + } + return d.pair, nil +} diff --git a/exchanges/orderbook/depth_test.go b/exchanges/orderbook/depth_test.go index ae46cf0f..e2c0bec3 100644 --- a/exchanges/orderbook/depth_test.go +++ b/exchanges/orderbook/depth_test.go @@ -2121,6 +2121,84 @@ func TestGetImbalance_Depth(t *testing.T) { } } +func TestGetTranches(t *testing.T) { + t.Parallel() + _, _, err := getInvalidDepth().GetTranches(0) + if !errors.Is(err, ErrOrderbookInvalid) { + t.Fatalf("received: '%v' but expected: '%v'", err, ErrOrderbookInvalid) + } + + depth := NewDepth(id) + + _, _, err = depth.GetTranches(-1) + if !errors.Is(err, errInvalidBookDepth) { + t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidBookDepth) + } + + askT, bidT, err := depth.GetTranches(0) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if len(askT) != 0 { + t.Fatalf("received: '%v' but expected: '%v'", len(askT), 0) + } + + if len(bidT) != 0 { + t.Fatalf("received: '%v' but expected: '%v'", len(bidT), 0) + } + + depth.LoadSnapshot(bid, ask, 0, time.Time{}, true) + + askT, bidT, err = depth.GetTranches(0) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if len(askT) != 20 { + t.Fatalf("received: '%v' but expected: '%v'", len(askT), 20) + } + + if len(bidT) != 20 { + t.Fatalf("received: '%v' but expected: '%v'", len(bidT), 20) + } + + askT, bidT, err = depth.GetTranches(5) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if len(askT) != 5 { + t.Fatalf("received: '%v' but expected: '%v'", len(askT), 5) + } + + if len(bidT) != 5 { + t.Fatalf("received: '%v' but expected: '%v'", len(bidT), 5) + } +} + +func TestGetPair(t *testing.T) { + t.Parallel() + depth := NewDepth(id) + + _, err := depth.GetPair() + if !errors.Is(err, currency.ErrCurrencyPairEmpty) { + t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrCurrencyPairEmpty) + } + + expected := currency.NewPair(currency.BTC, currency.WABI) + depth.pair = expected + + pair, err := depth.GetPair() + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if !pair.Equal(expected) { + t.Fatalf("received: '%v' but expected: '%v'", pair, expected) + } +} + func getInvalidDepth() *Depth { depth := NewDepth(id) _ = depth.Invalidate(errors.New("invalid reasoning")) diff --git a/exchanges/orderbook/linked_list.go b/exchanges/orderbook/linked_list.go index 36baf2c4..becaa40a 100644 --- a/exchanges/orderbook/linked_list.go +++ b/exchanges/orderbook/linked_list.go @@ -172,10 +172,13 @@ func (ll *linkedList) amount() (liquidity, value float64) { } // retrieve returns a full slice of contents from the linked list -func (ll *linkedList) retrieve() Items { - depth := make(Items, 0, ll.length) - for tip := ll.head; tip != nil; tip = tip.Next { - depth = append(depth, tip.Value) +func (ll *linkedList) retrieve(count int) Items { + if count == 0 || ll.length < count { + count = ll.length + } + depth := make(Items, count) + for i, tip := 0, ll.head; i < count && tip != nil; i, tip = i+1, tip.Next { + depth[i] = tip.Value } return depth } diff --git a/exchanges/orderbook/linked_list_test.go b/exchanges/orderbook/linked_list_test.go index 9ff65dc1..2b6a8283 100644 --- a/exchanges/orderbook/linked_list_test.go +++ b/exchanges/orderbook/linked_list_test.go @@ -460,11 +460,15 @@ func TestUpdateByID(t *testing.T) { t.Fatalf("expecting %v but received %v", nil, err) } - if a.retrieve()[1].Price == 0 { + if got := a.retrieve(2); len(got) != 2 || got[1].Price == 0 { t.Fatal("price should not be replaced with zero") } - if a.retrieve()[1].Amount != 1337 { + if got := a.retrieve(3); len(got) != 3 || got[1].Amount != 1337 { + t.Fatal("unexpected value for update") + } + + if got := a.retrieve(1000); len(got) != 6 { t.Fatal("unexpected value for update") } }