diff --git a/exchanges/fill/fill.go b/exchanges/fill/fill.go index 66dfa22f..aea552c1 100644 --- a/exchanges/fill/fill.go +++ b/exchanges/fill/fill.go @@ -1,5 +1,10 @@ package fill +import "errors" + +// ErrFeedDisabled is an error that indicates the fill feed is disabled +var ErrFeedDisabled = errors.New("fill feed disabled") + // Setup sets up the fill processor func (f *Fills) Setup(fillsFeedEnabled bool, c chan interface{}) { f.dataHandler = c @@ -14,9 +19,11 @@ func (f *Fills) Update(data ...Data) error { return nil } - if f.fillsFeedEnabled { - f.dataHandler <- data + if !f.fillsFeedEnabled { + return ErrFeedDisabled } + f.dataHandler <- data + return nil } diff --git a/exchanges/fill/fill_test.go b/exchanges/fill/fill_test.go new file mode 100644 index 00000000..e34733b3 --- /dev/null +++ b/exchanges/fill/fill_test.go @@ -0,0 +1,106 @@ +package fill + +import ( + "errors" + "testing" + "time" +) + +// TestSetup tests the setup function of the Fills struct +func TestSetup(t *testing.T) { + fill := &Fills{} + channel := make(chan interface{}) + fill.Setup(true, channel) + + if fill.dataHandler == nil { + t.Error("expected dataHandler to be set") + } + + if !fill.fillsFeedEnabled { + t.Error("expected fillsFeedEnabled to be true") + } +} + +// TestUpdateDisabledFeed tests the Update function when fillsFeedEnabled is false +func TestUpdateDisabledFeed(t *testing.T) { + channel := make(chan interface{}, 1) + fill := Fills{dataHandler: channel, fillsFeedEnabled: false} + + // Send a test data to the Update function + testData := Data{Timestamp: time.Now(), Price: 15.2, Amount: 3.2} + if err := fill.Update(testData); !errors.Is(err, ErrFeedDisabled) { + t.Errorf("Expected ErrFeedDisabled, got %v", err) + } + + select { + case <-channel: + t.Errorf("Expected no data on channel, got data") + default: + // nothing to do + } +} + +// TestUpdate tests the Update function of the Fills struct. +func TestUpdate(t *testing.T) { + channel := make(chan interface{}, 1) + fill := &Fills{dataHandler: channel, fillsFeedEnabled: true} + receivedData := Data{Timestamp: time.Now(), Price: 15.2, Amount: 3.2} + if err := fill.Update(receivedData); err != nil { + t.Errorf("Update returned error %v", err) + } + + select { + case data := <-channel: + dataSlice, ok := data.([]Data) + if !ok { + t.Errorf("expected []Data, got %T", data) + } + + if len(dataSlice) != 1 || dataSlice[0] != receivedData { + t.Errorf("expected data to be sent through channel") + } + default: + t.Errorf("No data sent to channel") + } +} + +// TestUpdateNoData tests the Update function with no Data objects +func TestUpdateNoData(t *testing.T) { + channel := make(chan interface{}, 1) + fill := &Fills{dataHandler: channel, fillsFeedEnabled: true} + if err := fill.Update(); err != nil { + t.Errorf("Update returned error %v", err) + } + + select { + case <-channel: + t.Errorf("Expected no data on channel, got data") + default: + // pass, nothing to do + } +} + +// TestUpdateMultipleData tests the Update function with multiple Data objects +func TestUpdateMultipleData(t *testing.T) { + channel := make(chan interface{}, 2) + fill := &Fills{dataHandler: channel, fillsFeedEnabled: true} + receivedData := Data{Timestamp: time.Now(), Price: 15.2, Amount: 3.2} + receivedData2 := Data{Timestamp: time.Now(), Price: 18.2, Amount: 9.0} + if err := fill.Update(receivedData, receivedData2); err != nil { + t.Errorf("Update returned error %v", err) + } + + select { + case data := <-channel: + dataSlice, ok := data.([]Data) + if !ok { + t.Errorf("expected []Data, got %T", data) + } + + if len(dataSlice) != 2 || dataSlice[0] != receivedData || dataSlice[1] != receivedData2 { + t.Errorf("expected data to be sent through channel") + } + default: + t.Errorf("No data sent to channel") + } +} diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index 9e17116f..814e12b7 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -50,6 +50,8 @@ func TestMain(m *testing.M) { gConf.API.Credentials.Key = apiKey gConf.API.Credentials.Secret = apiSecret g.Websocket = sharedtestvalues.NewTestWebsocket() + gConf.Features.Enabled.FillsFeed = true + gConf.Features.Enabled.TradeFeed = true err = g.Setup(gConf) if err != nil { log.Fatal("GateIO setup error", err) @@ -2584,7 +2586,7 @@ func TestWsTickerPushData(t *testing.T) { } } -const wsTradePushDataJSON = `{ "time": 1606292218, "channel": "spot.trades", "event": "update", "result": { "id": 309143071, "create_time": 1606292218, "create_time_ms": "1606292218213.4578", "side": "sell", "currency_pair": "GT_USDT", "amount": "16.4700000000", "price": "0.4705000000"}}` +const wsTradePushDataJSON = `{ "time": 1606292218, "channel": "spot.trades", "event": "update", "result": { "id": 309143071, "create_time": 1606292218, "create_time_ms": "1606292218213.4578", "side": "sell", "currency_pair": "BTC_USDT", "amount": "16.4700000000", "price": "0.4705000000"}}` func TestWsTradePushData(t *testing.T) { t.Parallel() diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 2d1e26d0..0afa6d21 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -51,7 +51,6 @@ const ( var defaultSubscriptions = []string{ spotTickerChannel, spotCandlesticksChannel, - spotTradesChannel, spotOrderbookTickerChannel, } @@ -197,6 +196,11 @@ func (g *Gateio) processTicker(incoming []byte, pushTime int64) error { } func (g *Gateio) processTrades(incoming []byte) error { + saveTradeData := g.IsSaveTradeDataEnabled() + if !saveTradeData && !g.IsTradeFeedEnabled() { + return nil + } + var data WsTrade err := json.Unmarshal(incoming, &data) if err != nil { @@ -207,7 +211,7 @@ func (g *Gateio) processTrades(incoming []byte) error { if err != nil { return err } - spotTradeData := trade.Data{ + tData := trade.Data{ Timestamp: data.CreateTimeMs.Time(), CurrencyPair: data.CurrencyPair, AssetType: asset.Spot, @@ -217,29 +221,16 @@ func (g *Gateio) processTrades(incoming []byte) error { Side: side, TID: strconv.FormatInt(data.ID, 10), } - assetPairEnabled := g.listOfAssetsCurrencyPairEnabledFor(data.CurrencyPair) - if assetPairEnabled[asset.Spot] { - err = trade.AddTradesToBuffer(g.Name, spotTradeData) - if err != nil { - return err - } - } - if assetPairEnabled[asset.Margin] { - marginTradeData := spotTradeData - marginTradeData.AssetType = asset.Margin - err = trade.AddTradesToBuffer(g.Name, marginTradeData) - if err != nil { - return err - } - } - if assetPairEnabled[asset.CrossMargin] { - crossMarginTradeData := spotTradeData - crossMarginTradeData.AssetType = asset.CrossMargin - err = trade.AddTradesToBuffer(g.Name, crossMarginTradeData) - if err != nil { - return err + + for _, assetType := range []asset.Item{asset.Spot, asset.Margin, asset.CrossMargin} { + if g.listOfAssetsCurrencyPairEnabledFor(data.CurrencyPair)[assetType] { + tData.AssetType = assetType + if err := g.Websocket.Trade.Update(saveTradeData, tData); err != nil { + return err + } } } + return nil } @@ -492,6 +483,10 @@ func (g *Gateio) processSpotOrders(data []byte) error { } func (g *Gateio) processUserPersonalTrades(data []byte) error { + if !g.IsFillsFeedEnabled() { + return nil + } + resp := struct { Time int64 `json:"time"` Channel string `json:"channel"` @@ -637,6 +632,11 @@ func (g *Gateio) GenerateDefaultSubscriptions() ([]stream.ChannelSubscription, e marginBalancesChannel, spotBalancesChannel}...) } + + if g.IsSaveTradeDataEnabled() || g.IsTradeFeedEnabled() { + channelsToSubscribe = append(channelsToSubscribe, spotTradesChannel) + } + var subscriptions []stream.ChannelSubscription var err error for i := range channelsToSubscribe { @@ -671,11 +671,7 @@ func (g *Gateio) GenerateDefaultSubscriptions() ([]stream.ChannelSubscription, e case spotOrderbookUpdateChannel: params["interval"] = kline.HundredMilliseconds } - if spotTradesChannel == channelsToSubscribe[i] { - if !g.IsSaveTradeDataEnabled() { - continue - } - } + fpair, err := g.FormatExchangeCurrency(pairs[j], asset.Spot) if err != nil { return nil, err diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index ef147818..2bdbdcf3 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -210,6 +210,8 @@ func (g *Gateio) Setup(exch *config.Exchange) error { GenerateSubscriptions: g.GenerateDefaultSubscriptions, ConnectionMonitorDelay: exch.ConnectionMonitorDelay, Features: &g.Features.Supports.WebsocketCapabilities, + FillsFeed: g.Features.Enabled.FillsFeed, + TradeFeed: g.Features.Enabled.TradeFeed, }) if err != nil { return err diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 115708c7..7fa8a653 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -440,6 +440,11 @@ func (g *Gateio) processFuturesTickers(data []byte, assetType asset.Item) error } func (g *Gateio) processFuturesTrades(data []byte, assetType asset.Item) error { + saveTradeData := g.IsSaveTradeDataEnabled() + if !saveTradeData && !g.IsTradeFeedEnabled() { + return nil + } + resp := struct { Time int64 `json:"time"` Channel string `json:"channel"` @@ -450,6 +455,7 @@ func (g *Gateio) processFuturesTrades(data []byte, assetType asset.Item) error { if err != nil { return err } + trades := make([]trade.Data, len(resp.Result)) for x := range resp.Result { trades[x] = trade.Data{ @@ -462,7 +468,7 @@ func (g *Gateio) processFuturesTrades(data []byte, assetType asset.Item) error { TID: strconv.FormatInt(resp.Result[x].ID, 10), } } - return trade.AddTradesToBuffer(g.Name, trades...) + return g.Websocket.Trade.Update(saveTradeData, trades...) } func (g *Gateio) processFuturesCandlesticks(data []byte, assetType asset.Item) error { @@ -677,6 +683,10 @@ func (g *Gateio) processFuturesOrdersPushData(data []byte, assetType asset.Item) } func (g *Gateio) procesFuturesUserTrades(data []byte, assetType asset.Item) error { + if !g.IsFillsFeedEnabled() { + return nil + } + resp := struct { Time int64 `json:"time"` Channel string `json:"channel"` diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index a111e1fb..a7b0dca2 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -661,6 +661,9 @@ func (g *Gateio) processOptionsOrderPushData(data []byte) error { } func (g *Gateio) processOptionsUserTradesPushData(data []byte) error { + if !g.IsFillsFeedEnabled() { + return nil + } resp := struct { Time int64 `json:"time"` Channel string `json:"channel"`