mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-27 23:16:51 +00:00
OrderManager: Fix race condition in submit with ws (#1336)
* OrderManager: Fix race condition in submit with ws If the ws sees the order before processSubmittedOrder then it will have assigned it an internal order id already and added it to the store. Don't treat that as an error. Instead just use the newer ws details * OrderManager: Fix error comparisson Should always use errors.Is when possible * Tests: Simplify btcusd test pair declaration * OrderManager: Improve test readability * OrderManager: Add orderstore.getByDetail test * Return a fresh pointer from orderstore.getByDetail This protects the order.Details in the store from direct access. The use-case was to allow the returned objects to be references so that future changes to them would be reflected. However we're not ready yet to allow people to touch the orders directly, because they're not protected directly by a mutex, and nothing would stop consumers contaminating the integrity of the data. We can revisit this topic later atomicly, but it's definitely tangental to the cause of action for PR #1336. * Fix GetByDetail tests to assert a new pointer * OrderManager: Avoid possible lock races This fix internalises the getByDetail because the implication of moving lock ownership out of exists/getByDetail to consumers breaks the order store struct encapsulation in a way we really don't want to. It's also more efficient * Fix spelling mistake Co-authored-by: Scott <gloriousCode@users.noreply.github.com> * OrderManager: Fix TestSubmitOrder... description * OrderManager: Improve clarity of comment * OrderManager: Capitalise error message On failure to add to orderstore, capitalise the error message Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io> --------- Co-authored-by: Scott <gloriousCode@users.noreply.github.com> Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>
This commit is contained in:
@@ -602,6 +602,14 @@ func (m *OrderManager) processSubmittedOrder(newOrderResp *order.SubmitResponse)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.orderStore.add(detail.CopyToPointer()); errors.Is(err, ErrOrdersAlreadyExists) {
|
||||
// Streamed by ws before we got here. Details from ws supersede since they are more recent.
|
||||
detail = m.orderStore.getByDetail(detail)
|
||||
} else if err != nil {
|
||||
// Non-fatal error: Unable to store order, but error does not need to be returned to caller
|
||||
log.Errorf(log.OrderMgr, "Unable to add %v order %v to orderStore: %s", detail.Exchange, detail.OrderID, err)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Exchange %s submitted order ID=%v [Ours: %v] pair=%v price=%v amount=%v quoteAmount=%v side=%v type=%v for time %v.",
|
||||
detail.Exchange,
|
||||
detail.OrderID,
|
||||
@@ -619,13 +627,7 @@ func (m *OrderManager) processSubmittedOrder(newOrderResp *order.SubmitResponse)
|
||||
m.orderStore.commsManager.PushEvent(base.Event{Type: "order", Message: msg})
|
||||
}
|
||||
|
||||
err = m.orderStore.add(detail.CopyToPointer())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to add %v order %v to orderStore: %s",
|
||||
detail.Exchange, detail.OrderID, err)
|
||||
}
|
||||
|
||||
return &OrderSubmitResponse{Detail: detail, InternalOrderID: id.String()}, nil
|
||||
return &OrderSubmitResponse{Detail: detail, InternalOrderID: detail.InternalOrderID.String()}, nil
|
||||
}
|
||||
|
||||
// processOrders iterates over all exchange orders via API
|
||||
@@ -1065,18 +1067,24 @@ func (s *store) upsert(od *order.Detail) (*OrderUpsertResponse, error) {
|
||||
|
||||
// exists verifies if the orderstore contains the provided order
|
||||
func (s *store) exists(det *order.Detail) bool {
|
||||
return s.getByDetail(det) != nil
|
||||
}
|
||||
|
||||
// getByDetail fetches an order from the store and returns it
|
||||
// returns nil if not found
|
||||
func (s *store) getByDetail(det *order.Detail) *order.Detail {
|
||||
if det == nil {
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
s.m.RLock()
|
||||
defer s.m.RUnlock()
|
||||
exchangeOrders := s.Orders[strings.ToLower(det.Exchange)]
|
||||
for x := range exchangeOrders {
|
||||
if exchangeOrders[x].OrderID == det.OrderID {
|
||||
return true
|
||||
for _, o := range exchangeOrders {
|
||||
if o.OrderID == det.OrderID {
|
||||
return o.CopyToPointer()
|
||||
}
|
||||
}
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add Adds an order to the orderStore for tracking the lifecycle
|
||||
@@ -1084,19 +1092,24 @@ func (s *store) add(det *order.Detail) error {
|
||||
if det == nil {
|
||||
return errNilOrder
|
||||
}
|
||||
|
||||
name := strings.ToLower(det.Exchange)
|
||||
_, err := s.exchangeManager.GetExchangeByName(name)
|
||||
if err != nil {
|
||||
if _, err := s.exchangeManager.GetExchangeByName(name); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.exists(det) { // TODO: Error on conflict; remove unnecessary locking.
|
||||
return ErrOrdersAlreadyExists
|
||||
|
||||
s.m.Lock()
|
||||
defer s.m.Unlock()
|
||||
|
||||
// Inline copy of getByDetail to avoid possible lock races
|
||||
for _, o := range s.Orders[name] {
|
||||
if o.OrderID == det.OrderID {
|
||||
return ErrOrdersAlreadyExists
|
||||
}
|
||||
}
|
||||
|
||||
// Untracked websocket orders will not have internalIDs yet
|
||||
det.GenerateInternalOrderID()
|
||||
s.m.Lock()
|
||||
defer s.m.Unlock()
|
||||
s.Orders[name] = append(s.Orders[name], det)
|
||||
if !det.AssetType.IsFutures() {
|
||||
return nil
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/convert"
|
||||
"github.com/thrasher-corp/gocryptotrader/config"
|
||||
@@ -28,6 +29,8 @@ type omfExchange struct {
|
||||
exchange.IBotExchange
|
||||
}
|
||||
|
||||
var btcusdPair = currency.NewPair(currency.BTC, currency.USD)
|
||||
|
||||
// CancelOrder overrides testExchange's cancel order function
|
||||
// to do the bare minimum required with no API calls or credentials required
|
||||
func (f omfExchange) CancelOrder(_ context.Context, _ *order.Cancel) error {
|
||||
@@ -477,17 +480,12 @@ func TestCancelOrder(t *testing.T) {
|
||||
t.Error("Expected error due to no order found")
|
||||
}
|
||||
|
||||
pair, err := currency.NewPairFromString("BTCUSD")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cancel := &order.Cancel{
|
||||
Exchange: testExchange,
|
||||
OrderID: "1337",
|
||||
Side: order.Sell,
|
||||
AssetType: asset.Spot,
|
||||
Pair: pair,
|
||||
Pair: btcusdPair,
|
||||
}
|
||||
err = m.Cancel(context.Background(), cancel)
|
||||
if !errors.Is(err, nil) {
|
||||
@@ -581,14 +579,9 @@ func TestSubmit(t *testing.T) {
|
||||
t.Error("Expected error from validation")
|
||||
}
|
||||
|
||||
pair, err := currency.NewPairFromString("BTCUSD")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m.cfg.EnforceLimitConfig = true
|
||||
m.cfg.AllowMarketOrders = false
|
||||
o.Pair = pair
|
||||
o.Pair = btcusdPair
|
||||
o.AssetType = asset.Spot
|
||||
o.Side = order.Buy
|
||||
o.Amount = 1
|
||||
@@ -646,6 +639,38 @@ func TestSubmit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubmitOrderAlreadyInStore ensures that if an order is submitted, but the WS sees the conf before processSubmittedOrder
|
||||
// then we don't error that it was there already
|
||||
func TestSubmitOrderAlreadyInStore(t *testing.T) {
|
||||
m := OrdersSetup(t)
|
||||
submitReq := &order.Submit{
|
||||
Type: order.Market,
|
||||
Pair: btcusdPair,
|
||||
AssetType: asset.Spot,
|
||||
Side: order.Buy,
|
||||
Amount: 1,
|
||||
Price: 1,
|
||||
Exchange: testExchange,
|
||||
}
|
||||
submitResp, err := submitReq.DeriveSubmitResponse("batman.obvs")
|
||||
assert.Nil(t, err, "Deriving a SubmitResp should not error")
|
||||
|
||||
id, err := uuid.NewV4()
|
||||
assert.Nil(t, err, "uuid should not error")
|
||||
d, err := submitResp.DeriveDetail(id)
|
||||
assert.Nil(t, err, "Derive Detail should not error")
|
||||
|
||||
d.ClientOrderID = "SecretSquirrelSauce"
|
||||
err = m.orderStore.add(d)
|
||||
assert.Nil(t, err, "Adding an order should not error")
|
||||
|
||||
resp, err := m.SubmitFakeOrder(submitReq, submitResp, false)
|
||||
|
||||
if assert.Nil(t, err, "SumbitFakeOrder should not error that the order is already in the store") {
|
||||
assert.Equal(t, d.ClientOrderID, resp.ClientOrderID, "resp should contain the ClientOrderID from the store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrderManager_Modify(t *testing.T) {
|
||||
pair := currency.Pair{
|
||||
Base: currency.NewCode("XXXXX"),
|
||||
@@ -1669,3 +1694,37 @@ func TestProcessFuturesPositions(t *testing.T) {
|
||||
t.Errorf("received '%v', expected '%v'", err, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetByDetail tests orderstore.getByDetail
|
||||
func TestGetByDetail(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := OrdersSetup(t)
|
||||
assert.Nil(t, m.orderStore.getByDetail(nil), "Fetching a nil order should return nil")
|
||||
od := &order.Detail{
|
||||
Exchange: testExchange,
|
||||
AssetType: asset.Spot,
|
||||
OrderID: "AdmiralHarold",
|
||||
ClientOrderID: "DuskyLeafMonkey",
|
||||
}
|
||||
id := &order.Detail{
|
||||
Exchange: od.Exchange,
|
||||
OrderID: od.OrderID,
|
||||
}
|
||||
|
||||
assert.Nil(t, m.orderStore.getByDetail(od), "Fetching a non-stored order should return nil")
|
||||
assert.Nil(t, m.orderStore.add(od), "Adding the details should not error")
|
||||
|
||||
byOrig := m.orderStore.getByDetail(od)
|
||||
byID := m.orderStore.getByDetail(id)
|
||||
|
||||
if assert.NotNil(t, byOrig, od, "Retrieve by orig pointer should find a record") {
|
||||
assert.NotSame(t, byOrig, od, "Retrieve by orig pointer should return a new pointer")
|
||||
assert.Equal(t, od.ClientOrderID, byOrig.ClientOrderID, "Retrieve by orig pointer should contain the correct ClientOrderID")
|
||||
}
|
||||
|
||||
if assert.NotNil(t, byID, od, "Retrieve by new pointer should find a record") {
|
||||
assert.NotSame(t, byID, id, "Retrieve by new pointer should return a different new pointer than we passed in")
|
||||
assert.NotSame(t, byID, od, "Retrieve by new pointer should return a different new pointer than the original object")
|
||||
assert.Equal(t, od.ClientOrderID, byID.ClientOrderID, "Retrieve by id pointer should contain the correct ClientOrderID")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ const (
|
||||
|
||||
var b = &Bitfinex{}
|
||||
var wsAuthExecuted bool
|
||||
var btcusdPair currency.Pair
|
||||
var btcusdPair = currency.NewPair(currency.BTC, currency.USD)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
b.SetDefaults()
|
||||
@@ -66,11 +66,6 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
b.WebsocketSubdChannels = make(map[int]*stream.ChannelSubscription)
|
||||
|
||||
btcusdPair, err = currency.NewPairFromString("BTCUSD")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user