mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-20 23:16:49 +00:00
orderbook: Add GetTranches and GetPair methods to Depth type (#1324)
* orderbook: Add GetTranches and GetPair methods to Depth type (cherry-pick) * glorious: nits * linter: fix --------- Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
"github.com/thrasher-corp/gocryptotrader/dispatch"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/alert"
|
||||
"github.com/thrasher-corp/gocryptotrader/log"
|
||||
@@ -18,6 +19,8 @@ var (
|
||||
ErrOrderbookInvalid = errors.New("orderbook data integrity compromised")
|
||||
// ErrInvalidAction defines and error when an action is invalid
|
||||
ErrInvalidAction = errors.New("invalid action")
|
||||
|
||||
errInvalidBookDepth = errors.New("invalid book depth")
|
||||
)
|
||||
|
||||
// Outbound restricts outbound usage of depth. NOTE: Type assert to
|
||||
@@ -73,8 +76,8 @@ func (d *Depth) Retrieve() (*Base, error) {
|
||||
return nil, d.validationError
|
||||
}
|
||||
return &Base{
|
||||
Bids: d.bids.retrieve(),
|
||||
Asks: d.asks.retrieve(),
|
||||
Bids: d.bids.retrieve(0),
|
||||
Asks: d.asks.retrieve(0),
|
||||
Exchange: d.exchange,
|
||||
Asset: d.asset,
|
||||
Pair: d.pair,
|
||||
@@ -711,3 +714,29 @@ func (d *Depth) GetImbalance() (float64, error) {
|
||||
}
|
||||
return (bidVolume - askVolume) / (bidVolume + askVolume), nil
|
||||
}
|
||||
|
||||
// GetTranches returns the desired tranche for the required depth count. If
|
||||
// count is 0, it will return the entire orderbook. Count == 1 will retrieve the
|
||||
// best bid and ask. If the required count exceeds the orderbook depth, it will
|
||||
// return the entire orderbook.
|
||||
func (d *Depth) GetTranches(count int) (ask, bid []Item, err error) {
|
||||
if count < 0 {
|
||||
return nil, nil, errInvalidBookDepth
|
||||
}
|
||||
d.m.Lock()
|
||||
defer d.m.Unlock()
|
||||
if d.validationError != nil {
|
||||
return nil, nil, d.validationError
|
||||
}
|
||||
return d.asks.retrieve(count), d.bids.retrieve(count), nil
|
||||
}
|
||||
|
||||
// GetPair returns the pair associated with the depth
|
||||
func (d *Depth) GetPair() (currency.Pair, error) {
|
||||
d.m.Lock()
|
||||
defer d.m.Unlock()
|
||||
if d.pair.IsEmpty() {
|
||||
return currency.Pair{}, currency.ErrCurrencyPairEmpty
|
||||
}
|
||||
return d.pair, nil
|
||||
}
|
||||
|
||||
@@ -2121,6 +2121,84 @@ func TestGetImbalance_Depth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTranches(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, _, err := getInvalidDepth().GetTranches(0)
|
||||
if !errors.Is(err, ErrOrderbookInvalid) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, ErrOrderbookInvalid)
|
||||
}
|
||||
|
||||
depth := NewDepth(id)
|
||||
|
||||
_, _, err = depth.GetTranches(-1)
|
||||
if !errors.Is(err, errInvalidBookDepth) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidBookDepth)
|
||||
}
|
||||
|
||||
askT, bidT, err := depth.GetTranches(0)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if len(askT) != 0 {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", len(askT), 0)
|
||||
}
|
||||
|
||||
if len(bidT) != 0 {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", len(bidT), 0)
|
||||
}
|
||||
|
||||
depth.LoadSnapshot(bid, ask, 0, time.Time{}, true)
|
||||
|
||||
askT, bidT, err = depth.GetTranches(0)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if len(askT) != 20 {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", len(askT), 20)
|
||||
}
|
||||
|
||||
if len(bidT) != 20 {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", len(bidT), 20)
|
||||
}
|
||||
|
||||
askT, bidT, err = depth.GetTranches(5)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if len(askT) != 5 {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", len(askT), 5)
|
||||
}
|
||||
|
||||
if len(bidT) != 5 {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", len(bidT), 5)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPair(t *testing.T) {
|
||||
t.Parallel()
|
||||
depth := NewDepth(id)
|
||||
|
||||
_, err := depth.GetPair()
|
||||
if !errors.Is(err, currency.ErrCurrencyPairEmpty) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrCurrencyPairEmpty)
|
||||
}
|
||||
|
||||
expected := currency.NewPair(currency.BTC, currency.WABI)
|
||||
depth.pair = expected
|
||||
|
||||
pair, err := depth.GetPair()
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if !pair.Equal(expected) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", pair, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func getInvalidDepth() *Depth {
|
||||
depth := NewDepth(id)
|
||||
_ = depth.Invalidate(errors.New("invalid reasoning"))
|
||||
|
||||
@@ -172,10 +172,13 @@ func (ll *linkedList) amount() (liquidity, value float64) {
|
||||
}
|
||||
|
||||
// retrieve returns a full slice of contents from the linked list
|
||||
func (ll *linkedList) retrieve() Items {
|
||||
depth := make(Items, 0, ll.length)
|
||||
for tip := ll.head; tip != nil; tip = tip.Next {
|
||||
depth = append(depth, tip.Value)
|
||||
func (ll *linkedList) retrieve(count int) Items {
|
||||
if count == 0 || ll.length < count {
|
||||
count = ll.length
|
||||
}
|
||||
depth := make(Items, count)
|
||||
for i, tip := 0, ll.head; i < count && tip != nil; i, tip = i+1, tip.Next {
|
||||
depth[i] = tip.Value
|
||||
}
|
||||
return depth
|
||||
}
|
||||
|
||||
@@ -460,11 +460,15 @@ func TestUpdateByID(t *testing.T) {
|
||||
t.Fatalf("expecting %v but received %v", nil, err)
|
||||
}
|
||||
|
||||
if a.retrieve()[1].Price == 0 {
|
||||
if got := a.retrieve(2); len(got) != 2 || got[1].Price == 0 {
|
||||
t.Fatal("price should not be replaced with zero")
|
||||
}
|
||||
|
||||
if a.retrieve()[1].Amount != 1337 {
|
||||
if got := a.retrieve(3); len(got) != 3 || got[1].Amount != 1337 {
|
||||
t.Fatal("unexpected value for update")
|
||||
}
|
||||
|
||||
if got := a.retrieve(1000); len(got) != 6 {
|
||||
t.Fatal("unexpected value for update")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user