orderbook: Add GetTranches and GetPair methods to Depth type (#1324)

* orderbook: Add GetTranches and GetPair methods to Depth type (cherry-pick)

* glorious: nits

* linter: fix

---------

Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
This commit is contained in:
Ryan O'Hara-Reid
2023-08-22 13:29:17 +10:00
committed by GitHub
parent 9c83231696
commit c5240153f9
4 changed files with 122 additions and 8 deletions

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/gofrs/uuid"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/dispatch"
"github.com/thrasher-corp/gocryptotrader/exchanges/alert"
"github.com/thrasher-corp/gocryptotrader/log"
@@ -18,6 +19,8 @@ var (
ErrOrderbookInvalid = errors.New("orderbook data integrity compromised")
// ErrInvalidAction defines and error when an action is invalid
ErrInvalidAction = errors.New("invalid action")
errInvalidBookDepth = errors.New("invalid book depth")
)
// Outbound restricts outbound usage of depth. NOTE: Type assert to
@@ -73,8 +76,8 @@ func (d *Depth) Retrieve() (*Base, error) {
return nil, d.validationError
}
return &Base{
Bids: d.bids.retrieve(),
Asks: d.asks.retrieve(),
Bids: d.bids.retrieve(0),
Asks: d.asks.retrieve(0),
Exchange: d.exchange,
Asset: d.asset,
Pair: d.pair,
@@ -711,3 +714,29 @@ func (d *Depth) GetImbalance() (float64, error) {
}
return (bidVolume - askVolume) / (bidVolume + askVolume), nil
}
// GetTranches returns the desired tranche for the required depth count. If
// count is 0, it will return the entire orderbook. Count == 1 will retrieve the
// best bid and ask. If the required count exceeds the orderbook depth, it will
// return the entire orderbook.
func (d *Depth) GetTranches(count int) (ask, bid []Item, err error) {
if count < 0 {
return nil, nil, errInvalidBookDepth
}
d.m.Lock()
defer d.m.Unlock()
if d.validationError != nil {
return nil, nil, d.validationError
}
return d.asks.retrieve(count), d.bids.retrieve(count), nil
}
// GetPair returns the pair associated with the depth
func (d *Depth) GetPair() (currency.Pair, error) {
d.m.Lock()
defer d.m.Unlock()
if d.pair.IsEmpty() {
return currency.Pair{}, currency.ErrCurrencyPairEmpty
}
return d.pair, nil
}

View File

@@ -2121,6 +2121,84 @@ func TestGetImbalance_Depth(t *testing.T) {
}
}
func TestGetTranches(t *testing.T) {
t.Parallel()
_, _, err := getInvalidDepth().GetTranches(0)
if !errors.Is(err, ErrOrderbookInvalid) {
t.Fatalf("received: '%v' but expected: '%v'", err, ErrOrderbookInvalid)
}
depth := NewDepth(id)
_, _, err = depth.GetTranches(-1)
if !errors.Is(err, errInvalidBookDepth) {
t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidBookDepth)
}
askT, bidT, err := depth.GetTranches(0)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if len(askT) != 0 {
t.Fatalf("received: '%v' but expected: '%v'", len(askT), 0)
}
if len(bidT) != 0 {
t.Fatalf("received: '%v' but expected: '%v'", len(bidT), 0)
}
depth.LoadSnapshot(bid, ask, 0, time.Time{}, true)
askT, bidT, err = depth.GetTranches(0)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if len(askT) != 20 {
t.Fatalf("received: '%v' but expected: '%v'", len(askT), 20)
}
if len(bidT) != 20 {
t.Fatalf("received: '%v' but expected: '%v'", len(bidT), 20)
}
askT, bidT, err = depth.GetTranches(5)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if len(askT) != 5 {
t.Fatalf("received: '%v' but expected: '%v'", len(askT), 5)
}
if len(bidT) != 5 {
t.Fatalf("received: '%v' but expected: '%v'", len(bidT), 5)
}
}
func TestGetPair(t *testing.T) {
t.Parallel()
depth := NewDepth(id)
_, err := depth.GetPair()
if !errors.Is(err, currency.ErrCurrencyPairEmpty) {
t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrCurrencyPairEmpty)
}
expected := currency.NewPair(currency.BTC, currency.WABI)
depth.pair = expected
pair, err := depth.GetPair()
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if !pair.Equal(expected) {
t.Fatalf("received: '%v' but expected: '%v'", pair, expected)
}
}
func getInvalidDepth() *Depth {
depth := NewDepth(id)
_ = depth.Invalidate(errors.New("invalid reasoning"))

View File

@@ -172,10 +172,13 @@ func (ll *linkedList) amount() (liquidity, value float64) {
}
// retrieve returns a full slice of contents from the linked list
func (ll *linkedList) retrieve() Items {
depth := make(Items, 0, ll.length)
for tip := ll.head; tip != nil; tip = tip.Next {
depth = append(depth, tip.Value)
func (ll *linkedList) retrieve(count int) Items {
if count == 0 || ll.length < count {
count = ll.length
}
depth := make(Items, count)
for i, tip := 0, ll.head; i < count && tip != nil; i, tip = i+1, tip.Next {
depth[i] = tip.Value
}
return depth
}

View File

@@ -460,11 +460,15 @@ func TestUpdateByID(t *testing.T) {
t.Fatalf("expecting %v but received %v", nil, err)
}
if a.retrieve()[1].Price == 0 {
if got := a.retrieve(2); len(got) != 2 || got[1].Price == 0 {
t.Fatal("price should not be replaced with zero")
}
if a.retrieve()[1].Amount != 1337 {
if got := a.retrieve(3); len(got) != 3 || got[1].Amount != 1337 {
t.Fatal("unexpected value for update")
}
if got := a.retrieve(1000); len(got) != 6 {
t.Fatal("unexpected value for update")
}
}