diff --git a/engine/order_manager.go b/engine/order_manager.go index 847f831a..8224a3d3 100644 --- a/engine/order_manager.go +++ b/engine/order_manager.go @@ -602,6 +602,14 @@ func (m *OrderManager) processSubmittedOrder(newOrderResp *order.SubmitResponse) return nil, err } + if err := m.orderStore.add(detail.CopyToPointer()); errors.Is(err, ErrOrdersAlreadyExists) { + // Streamed by ws before we got here. Details from ws supersede since they are more recent. + detail = m.orderStore.getByDetail(detail) + } else if err != nil { + // Non-fatal error: Unable to store order, but error does not need to be returned to caller + log.Errorf(log.OrderMgr, "Unable to add %v order %v to orderStore: %s", detail.Exchange, detail.OrderID, err) + } + msg := fmt.Sprintf("Exchange %s submitted order ID=%v [Ours: %v] pair=%v price=%v amount=%v quoteAmount=%v side=%v type=%v for time %v.", detail.Exchange, detail.OrderID, @@ -619,13 +627,7 @@ func (m *OrderManager) processSubmittedOrder(newOrderResp *order.SubmitResponse) m.orderStore.commsManager.PushEvent(base.Event{Type: "order", Message: msg}) } - err = m.orderStore.add(detail.CopyToPointer()) - if err != nil { - return nil, fmt.Errorf("unable to add %v order %v to orderStore: %s", - detail.Exchange, detail.OrderID, err) - } - - return &OrderSubmitResponse{Detail: detail, InternalOrderID: id.String()}, nil + return &OrderSubmitResponse{Detail: detail, InternalOrderID: detail.InternalOrderID.String()}, nil } // processOrders iterates over all exchange orders via API @@ -1065,18 +1067,24 @@ func (s *store) upsert(od *order.Detail) (*OrderUpsertResponse, error) { // exists verifies if the orderstore contains the provided order func (s *store) exists(det *order.Detail) bool { + return s.getByDetail(det) != nil +} + +// getByDetail fetches an order from the store and returns it +// returns nil if not found +func (s *store) getByDetail(det *order.Detail) *order.Detail { if det == nil { - return false + return nil } s.m.RLock() defer s.m.RUnlock() exchangeOrders := s.Orders[strings.ToLower(det.Exchange)] - for x := range exchangeOrders { - if exchangeOrders[x].OrderID == det.OrderID { - return true + for _, o := range exchangeOrders { + if o.OrderID == det.OrderID { + return o.CopyToPointer() } } - return false + return nil } // Add Adds an order to the orderStore for tracking the lifecycle @@ -1084,19 +1092,24 @@ func (s *store) add(det *order.Detail) error { if det == nil { return errNilOrder } + name := strings.ToLower(det.Exchange) - _, err := s.exchangeManager.GetExchangeByName(name) - if err != nil { + if _, err := s.exchangeManager.GetExchangeByName(name); err != nil { return err } - if s.exists(det) { // TODO: Error on conflict; remove unnecessary locking. - return ErrOrdersAlreadyExists + + s.m.Lock() + defer s.m.Unlock() + + // Inline copy of getByDetail to avoid possible lock races + for _, o := range s.Orders[name] { + if o.OrderID == det.OrderID { + return ErrOrdersAlreadyExists + } } // Untracked websocket orders will not have internalIDs yet det.GenerateInternalOrderID() - s.m.Lock() - defer s.m.Unlock() s.Orders[name] = append(s.Orders[name], det) if !det.AssetType.IsFutures() { return nil diff --git a/engine/order_manager_test.go b/engine/order_manager_test.go index 67ca2d86..cbdc7659 100644 --- a/engine/order_manager_test.go +++ b/engine/order_manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/gofrs/uuid" "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/convert" "github.com/thrasher-corp/gocryptotrader/config" @@ -28,6 +29,8 @@ type omfExchange struct { exchange.IBotExchange } +var btcusdPair = currency.NewPair(currency.BTC, currency.USD) + // CancelOrder overrides testExchange's cancel order function // to do the bare minimum required with no API calls or credentials required func (f omfExchange) CancelOrder(_ context.Context, _ *order.Cancel) error { @@ -477,17 +480,12 @@ func TestCancelOrder(t *testing.T) { t.Error("Expected error due to no order found") } - pair, err := currency.NewPairFromString("BTCUSD") - if err != nil { - t.Fatal(err) - } - cancel := &order.Cancel{ Exchange: testExchange, OrderID: "1337", Side: order.Sell, AssetType: asset.Spot, - Pair: pair, + Pair: btcusdPair, } err = m.Cancel(context.Background(), cancel) if !errors.Is(err, nil) { @@ -581,14 +579,9 @@ func TestSubmit(t *testing.T) { t.Error("Expected error from validation") } - pair, err := currency.NewPairFromString("BTCUSD") - if err != nil { - t.Fatal(err) - } - m.cfg.EnforceLimitConfig = true m.cfg.AllowMarketOrders = false - o.Pair = pair + o.Pair = btcusdPair o.AssetType = asset.Spot o.Side = order.Buy o.Amount = 1 @@ -646,6 +639,38 @@ func TestSubmit(t *testing.T) { } } +// TestSubmitOrderAlreadyInStore ensures that if an order is submitted, but the WS sees the conf before processSubmittedOrder +// then we don't error that it was there already +func TestSubmitOrderAlreadyInStore(t *testing.T) { + m := OrdersSetup(t) + submitReq := &order.Submit{ + Type: order.Market, + Pair: btcusdPair, + AssetType: asset.Spot, + Side: order.Buy, + Amount: 1, + Price: 1, + Exchange: testExchange, + } + submitResp, err := submitReq.DeriveSubmitResponse("batman.obvs") + assert.Nil(t, err, "Deriving a SubmitResp should not error") + + id, err := uuid.NewV4() + assert.Nil(t, err, "uuid should not error") + d, err := submitResp.DeriveDetail(id) + assert.Nil(t, err, "Derive Detail should not error") + + d.ClientOrderID = "SecretSquirrelSauce" + err = m.orderStore.add(d) + assert.Nil(t, err, "Adding an order should not error") + + resp, err := m.SubmitFakeOrder(submitReq, submitResp, false) + + if assert.Nil(t, err, "SumbitFakeOrder should not error that the order is already in the store") { + assert.Equal(t, d.ClientOrderID, resp.ClientOrderID, "resp should contain the ClientOrderID from the store") + } +} + func TestOrderManager_Modify(t *testing.T) { pair := currency.Pair{ Base: currency.NewCode("XXXXX"), @@ -1669,3 +1694,37 @@ func TestProcessFuturesPositions(t *testing.T) { t.Errorf("received '%v', expected '%v'", err, nil) } } + +// TestGetByDetail tests orderstore.getByDetail +func TestGetByDetail(t *testing.T) { + t.Parallel() + m := OrdersSetup(t) + assert.Nil(t, m.orderStore.getByDetail(nil), "Fetching a nil order should return nil") + od := &order.Detail{ + Exchange: testExchange, + AssetType: asset.Spot, + OrderID: "AdmiralHarold", + ClientOrderID: "DuskyLeafMonkey", + } + id := &order.Detail{ + Exchange: od.Exchange, + OrderID: od.OrderID, + } + + assert.Nil(t, m.orderStore.getByDetail(od), "Fetching a non-stored order should return nil") + assert.Nil(t, m.orderStore.add(od), "Adding the details should not error") + + byOrig := m.orderStore.getByDetail(od) + byID := m.orderStore.getByDetail(id) + + if assert.NotNil(t, byOrig, od, "Retrieve by orig pointer should find a record") { + assert.NotSame(t, byOrig, od, "Retrieve by orig pointer should return a new pointer") + assert.Equal(t, od.ClientOrderID, byOrig.ClientOrderID, "Retrieve by orig pointer should contain the correct ClientOrderID") + } + + if assert.NotNil(t, byID, od, "Retrieve by new pointer should find a record") { + assert.NotSame(t, byID, id, "Retrieve by new pointer should return a different new pointer than we passed in") + assert.NotSame(t, byID, od, "Retrieve by new pointer should return a different new pointer than the original object") + assert.Equal(t, od.ClientOrderID, byID.ClientOrderID, "Retrieve by id pointer should contain the correct ClientOrderID") + } +} diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index 40f3a0cb..8f0117df 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -36,7 +36,7 @@ const ( var b = &Bitfinex{} var wsAuthExecuted bool -var btcusdPair currency.Pair +var btcusdPair = currency.NewPair(currency.BTC, currency.USD) func TestMain(m *testing.M) { b.SetDefaults() @@ -66,11 +66,6 @@ func TestMain(m *testing.M) { } b.WebsocketSubdChannels = make(map[int]*stream.ChannelSubscription) - btcusdPair, err = currency.NewPairFromString("BTCUSD") - if err != nil { - log.Fatal(err) - } - os.Exit(m.Run()) }