mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-13 15:09:42 +00:00
Kraken: Fix wsCancelOrders not erroring with order id (#1505)
* Kraken: Fix wsCancelOrders not erroring order id We were using the "cancel many" facility of the Kraken api. However since that doesn't actually report errors individually, it seems saner to just multiplex over it. We were going to get N+ responses anyway. Might as well send N+ requests * Common: Add ErrorCollector and methods
This commit is contained in:
@@ -578,6 +578,34 @@ func ExcludeError(err, excl error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// ErrorCollector allows collecting a stream of errors from concurrent go routines
|
||||
// Users should call e.Wg.Done and send errors to e.C
|
||||
type ErrorCollector struct {
|
||||
C chan error
|
||||
Wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// CollectErrors returns an ErrorCollector with WaitGroup and Channel buffer set to n
|
||||
func CollectErrors(n int) *ErrorCollector {
|
||||
e := &ErrorCollector{
|
||||
C: make(chan error, n),
|
||||
}
|
||||
e.Wg.Add(n)
|
||||
return e
|
||||
}
|
||||
|
||||
// Collect runs waits for e.Wg to be Done, closes the error channel, and return a error collection
|
||||
func (e *ErrorCollector) Collect() (errs error) {
|
||||
e.Wg.Wait()
|
||||
close(e.C)
|
||||
for err := range e.C {
|
||||
if err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// StartEndTimeCheck provides some basic checks which occur
|
||||
// frequently in the codebase
|
||||
func StartEndTimeCheck(start, end time.Time) error {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/file"
|
||||
)
|
||||
|
||||
@@ -836,3 +837,22 @@ func TestGenerateRandomString(t *testing.T) {
|
||||
t.Error("GenerateRandomString() unexpected test validation result")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorCollector exercises the error collector
|
||||
func TestErrorCollector(t *testing.T) {
|
||||
e := CollectErrors(4)
|
||||
for i := range 4 {
|
||||
go func() {
|
||||
if i%2 == 0 {
|
||||
e.C <- errors.New("Collected error")
|
||||
} else {
|
||||
e.C <- nil
|
||||
}
|
||||
e.Wg.Done()
|
||||
}()
|
||||
}
|
||||
v := e.Collect()
|
||||
errs, ok := v.(*multiError)
|
||||
require.True(t, ok, "Must return a multiError")
|
||||
assert.Len(t, errs.Unwrap(), 2, "Should have 2 errors")
|
||||
}
|
||||
|
||||
@@ -1987,7 +1987,7 @@ func TestSubscribe(t *testing.T) {
|
||||
{Channel: "btcusdt@trade"},
|
||||
}
|
||||
if mockTests {
|
||||
b = testexch.MockWSInstance[Binance](t, func(msg []byte, w *websocket.Conn) error {
|
||||
mock := func(msg []byte, w *websocket.Conn) error {
|
||||
var req WsPayload
|
||||
err := json.Unmarshal(msg, &req)
|
||||
require.NoError(t, err, "Unmarshal should not error")
|
||||
@@ -1995,7 +1995,8 @@ func TestSubscribe(t *testing.T) {
|
||||
assert.Equal(t, req.Params[0], channels[0].Channel, "Channel name should be correct")
|
||||
assert.Equal(t, req.Params[1], channels[1].Channel, "Channel name should be correct")
|
||||
return w.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf(`{"result":null,"id":%d}`, req.ID)))
|
||||
})
|
||||
}
|
||||
b = testexch.MockWsInstance[Binance](t, testexch.CurryWsMockUpgrader(t, mock))
|
||||
} else {
|
||||
testexch.SetupWs(t, b)
|
||||
}
|
||||
@@ -2010,12 +2011,13 @@ func TestSubscribeBadResp(t *testing.T) {
|
||||
channels := []subscription.Subscription{
|
||||
{Channel: "moons@ticker"},
|
||||
}
|
||||
b := testexch.MockWSInstance[Binance](t, func(msg []byte, w *websocket.Conn) error { //nolint:govet // shadow
|
||||
mock := func(msg []byte, w *websocket.Conn) error {
|
||||
var req WsPayload
|
||||
err := json.Unmarshal(msg, &req)
|
||||
require.NoError(t, err, "Unmarshal should not error")
|
||||
return w.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf(`{"result":{"error":"carrots"},"id":%d}`, req.ID)))
|
||||
})
|
||||
}
|
||||
b := testexch.MockWsInstance[Binance](t, testexch.CurryWsMockUpgrader(t, mock)) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes
|
||||
err := b.Subscribe(channels)
|
||||
assert.ErrorIs(t, err, stream.ErrSubscriptionFailure, "Subscribe should error ErrSubscriptionFailure")
|
||||
assert.ErrorIs(t, err, errUnknownError, "Subscribe should error errUnknownError")
|
||||
|
||||
@@ -2,6 +2,7 @@ package kraken
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -17,7 +18,6 @@ import (
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/convert"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/key"
|
||||
"github.com/thrasher-corp/gocryptotrader/config"
|
||||
"github.com/thrasher-corp/gocryptotrader/core"
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
exchange "github.com/thrasher-corp/gocryptotrader/exchanges"
|
||||
@@ -31,10 +31,11 @@ import (
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/ticker"
|
||||
testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange"
|
||||
"github.com/thrasher-corp/gocryptotrader/portfolio/withdraw"
|
||||
)
|
||||
|
||||
var k = &Kraken{}
|
||||
var k *Kraken
|
||||
var wsSetupRan bool
|
||||
|
||||
// Please add your own APIkeys to do correct due diligence testing.
|
||||
@@ -44,33 +45,23 @@ const (
|
||||
canManipulateRealOrders = false
|
||||
)
|
||||
|
||||
// TestSetup setup func
|
||||
func TestMain(m *testing.M) {
|
||||
k.SetDefaults()
|
||||
cfg := config.GetConfig()
|
||||
err := cfg.LoadConfig("../../testdata/configtest.json", true)
|
||||
if err != nil {
|
||||
k = new(Kraken)
|
||||
if err := testexch.TestInstance(k); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
krakenConfig, err := cfg.GetExchangeConfig("Kraken")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
krakenConfig.API.AuthenticatedSupport = true
|
||||
krakenConfig.API.Credentials.Key = apiKey
|
||||
krakenConfig.API.Credentials.Secret = apiSecret
|
||||
k.Websocket = sharedtestvalues.NewTestWebsocket()
|
||||
err = k.Setup(krakenConfig)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
err = k.UpdateTradablePairs(context.Background(), true)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
if apiKey != "" && apiSecret != "" {
|
||||
k.API.AuthenticatedSupport = true
|
||||
k.SetCredentials(apiKey, apiSecret, "", "", "", "")
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestUpdateTradablePairs(t *testing.T) {
|
||||
t.Parallel()
|
||||
testexch.UpdatePairsOnce(t, k)
|
||||
}
|
||||
|
||||
func TestGetCurrentServerTime(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := k.GetCurrentServerTime(context.Background())
|
||||
@@ -139,6 +130,7 @@ func TestUpdateTicker(t *testing.T) {
|
||||
|
||||
func TestUpdateTickers(t *testing.T) {
|
||||
t.Parallel()
|
||||
testexch.UpdatePairsOnce(t, k)
|
||||
ap, err := k.GetAvailablePairs(asset.Spot)
|
||||
require.NoError(t, err)
|
||||
err = k.CurrencyPairs.StorePairs(asset.Spot, ap, true)
|
||||
@@ -1253,7 +1245,9 @@ func TestGetWSToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWsAddOrder(t *testing.T) {
|
||||
setupWsTests(t)
|
||||
t.Parallel()
|
||||
sharedtestvalues.SkipTestIfCredentialsUnset(t, k, canManipulateRealOrders)
|
||||
testexch.SetupWs(t, k)
|
||||
_, err := k.wsAddOrder(&WsAddOrderRequest{
|
||||
OrderType: order.Limit.Lower(),
|
||||
OrderSide: order.Buy.Lower(),
|
||||
@@ -1265,11 +1259,42 @@ func TestWsAddOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWsCancelOrder(t *testing.T) {
|
||||
setupWsTests(t)
|
||||
if err := k.wsCancelOrders([]string{"1337"}); err != nil {
|
||||
t.Error(err)
|
||||
func mockWsCancelOrders(msg []byte, w *websocket.Conn) error {
|
||||
var req WsCancelOrderRequest
|
||||
if err := json.Unmarshal(msg, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
resp := WsCancelOrderResponse{
|
||||
Event: krakenWsCancelOrderStatus,
|
||||
Status: "ok",
|
||||
RequestID: req.RequestID,
|
||||
Count: int64(len(req.TransactionIDs)),
|
||||
}
|
||||
if len(req.TransactionIDs) == 0 || strings.Contains(req.TransactionIDs[0], "FISH") { // Reject anything that smells suspicious
|
||||
resp.Status = "error"
|
||||
resp.ErrorMessage = "[EOrder:Unknown order]"
|
||||
}
|
||||
respJSON, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return w.WriteMessage(websocket.TextMessage, respJSON)
|
||||
}
|
||||
|
||||
func TestWsCancelOrders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
k := testexch.MockWsInstance[Kraken](t, curryWsMockUpgrader(t, mockWsCancelOrders)) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes
|
||||
require.True(t, k.IsWebsocketAuthenticationSupported(), "WS must be authenticated")
|
||||
|
||||
err := k.wsCancelOrders([]string{"RABBIT", "BATFISH", "SQUIRREL", "CATFISH", "MOUSE"})
|
||||
assert.ErrorIs(t, err, errCancellingOrder, "Should error cancelling order")
|
||||
assert.ErrorContains(t, err, "BATFISH", "Should error containing txn id")
|
||||
assert.ErrorContains(t, err, "CATFISH", "Should error containing txn id")
|
||||
assert.ErrorContains(t, err, "[EOrder:Unknown order]", "Should error containing server error")
|
||||
|
||||
err = k.wsCancelOrders([]string{"RABBIT", "SQUIRREL", "MOUSE"})
|
||||
assert.NoError(t, err, "Should not error with valid ids")
|
||||
}
|
||||
|
||||
func TestWsCancelAllOrders(t *testing.T) {
|
||||
@@ -1803,6 +1828,7 @@ func TestWsOwnTrades(t *testing.T) {
|
||||
func TestWsOpenOrders(t *testing.T) {
|
||||
t.Parallel()
|
||||
n := new(Kraken)
|
||||
testexch.UpdatePairsOnce(t, k)
|
||||
sharedtestvalues.TestFixtureToDataHandler(t, k, n, "testdata/wsOpenTrades.json", n.wsHandleData)
|
||||
seen := 0
|
||||
for reading := true; reading; {
|
||||
@@ -1884,18 +1910,6 @@ func TestWsAddOrderJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWsCancelOrderJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
pressXToJSON := []byte(`{
|
||||
"event": "cancelOrderStatus",
|
||||
"status": "ok"
|
||||
}`)
|
||||
err := k.wsHandleData(pressXToJSON)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTime(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test REST example
|
||||
@@ -2282,3 +2296,16 @@ func TestGetOpenInterest(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, resp)
|
||||
}
|
||||
|
||||
// curryWsMockUpgrader handles Kraken specific http auth token responses prior to handling off to standard Websocket upgrader
|
||||
func curryWsMockUpgrader(tb testing.TB, h testexch.WsMockFunc) http.HandlerFunc {
|
||||
tb.Helper()
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "GetWebSocketsToken") {
|
||||
_, err := w.Write([]byte(`{"result":{"token":"mockAuth"}}`))
|
||||
require.NoError(tb, err, "Write should not error")
|
||||
return
|
||||
}
|
||||
testexch.WsMockUpgrader(tb, w, r, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/buger/jsonparser"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/convert"
|
||||
@@ -64,6 +65,9 @@ var (
|
||||
pingRequest = WebsocketBaseEventRequest{Event: stream.Ping}
|
||||
m sync.Mutex
|
||||
errNoWebsocketOrderbookData = errors.New("no websocket orderbook data")
|
||||
errParsingWSField = errors.New("error parsing WS field")
|
||||
errUnknownError = errors.New("unknown error")
|
||||
errCancellingOrder = errors.New("error cancelling order")
|
||||
)
|
||||
|
||||
// Channels require a topic and a currency
|
||||
@@ -76,14 +80,6 @@ var defaultSubscribedChannels = []string{
|
||||
krakenWsSpread}
|
||||
var authenticatedChannels = []string{krakenWsOwnTrades, krakenWsOpenOrders}
|
||||
|
||||
var cancelOrdersStatusMutex sync.Mutex
|
||||
var cancelOrdersStatus = make(map[int64]*struct {
|
||||
Total int // total count of orders in wsCancelOrders request
|
||||
Successful int // numbers of Successfully canceled orders in wsCancelOrders request
|
||||
Unsuccessful int // numbers of Unsuccessfully canceled orders in wsCancelOrders request
|
||||
Error string // if at least one of requested order return fail, store error here
|
||||
})
|
||||
|
||||
// WsConnect initiates a websocket connection
|
||||
func (k *Kraken) WsConnect() error {
|
||||
if !k.Websocket.IsEnabled() || !k.IsEnabled() {
|
||||
@@ -188,26 +184,6 @@ func (k *Kraken) wsReadData(comms chan stream.Response) {
|
||||
}
|
||||
}
|
||||
|
||||
// awaitForCancelOrderResponses used to wait until all responses will received for appropriate CancelOrder request
|
||||
// success param = was the response from Kraken successful or not
|
||||
func isAwaitingCancelOrderResponses(requestID int64, success bool) bool {
|
||||
cancelOrdersStatusMutex.Lock()
|
||||
if stat, ok := cancelOrdersStatus[requestID]; ok {
|
||||
if success {
|
||||
cancelOrdersStatus[requestID].Successful++
|
||||
} else {
|
||||
cancelOrdersStatus[requestID].Unsuccessful++
|
||||
}
|
||||
|
||||
if stat.Successful+stat.Unsuccessful != stat.Total {
|
||||
cancelOrdersStatusMutex.Unlock()
|
||||
return true
|
||||
}
|
||||
}
|
||||
cancelOrdersStatusMutex.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
func (k *Kraken) wsHandleData(respRaw []byte) error {
|
||||
if strings.HasPrefix(string(respRaw), "[") {
|
||||
var dataResponse WebsocketDataResponse
|
||||
@@ -231,45 +207,19 @@ func (k *Kraken) wsHandleData(respRaw []byte) error {
|
||||
var eventResponse map[string]interface{}
|
||||
err := json.Unmarshal(respRaw, &eventResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s - err %s could not parse websocket data: %s",
|
||||
k.Name,
|
||||
err,
|
||||
respRaw)
|
||||
return fmt.Errorf("%s - err %s could not parse websocket data: %s", k.Name, err, respRaw)
|
||||
}
|
||||
if event, ok := eventResponse["event"]; ok {
|
||||
switch event {
|
||||
case stream.Pong, krakenWsHeartbeat:
|
||||
return nil
|
||||
case krakenWsCancelOrderStatus:
|
||||
var status WsCancelOrderResponse
|
||||
err := json.Unmarshal(respRaw, &status)
|
||||
id, err := jsonparser.GetInt(respRaw, "reqid")
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s - err %s unable to parse WsCancelOrderResponse: %s",
|
||||
k.Name,
|
||||
err,
|
||||
respRaw)
|
||||
return fmt.Errorf("%w 'reqid': %w from message: %s", errParsingWSField, err, respRaw)
|
||||
}
|
||||
|
||||
success := true
|
||||
if status.Status == "error" {
|
||||
success = false
|
||||
cancelOrdersStatusMutex.Lock()
|
||||
if _, ok := cancelOrdersStatus[status.RequestID]; ok {
|
||||
if cancelOrdersStatus[status.RequestID].Error == "" { // save the first error, if any
|
||||
cancelOrdersStatus[status.RequestID].Error = status.ErrorMessage
|
||||
}
|
||||
}
|
||||
cancelOrdersStatusMutex.Unlock()
|
||||
}
|
||||
|
||||
if isAwaitingCancelOrderResponses(status.RequestID, success) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// all responses handled, return results stored in cancelOrdersStatus
|
||||
if status.RequestID > 0 && !k.Websocket.Match.IncomingWithData(status.RequestID, respRaw) {
|
||||
return fmt.Errorf("can't send ws incoming data to Matched channel with RequestID: %d",
|
||||
status.RequestID)
|
||||
if !k.Websocket.Match.IncomingWithData(id, respRaw) {
|
||||
return fmt.Errorf("%v cancel order listener not found", id)
|
||||
}
|
||||
case krakenWsCancelAllOrderStatus:
|
||||
var status WsCancelOrderResponse
|
||||
@@ -1383,43 +1333,48 @@ func (k *Kraken) wsAddOrder(request *WsAddOrderRequest) (string, error) {
|
||||
return resp.TransactionID, nil
|
||||
}
|
||||
|
||||
// wsCancelOrders cancels one or more open orders passed in orderIDs param
|
||||
// wsCancelOrders cancels open orders concurrently
|
||||
// It does not use the multiple txId facility of the cancelOrder API because the errors are not specific
|
||||
func (k *Kraken) wsCancelOrders(orderIDs []string) error {
|
||||
errs := common.CollectErrors(len(orderIDs))
|
||||
for _, id := range orderIDs {
|
||||
go func() {
|
||||
defer errs.Wg.Done()
|
||||
errs.C <- k.wsCancelOrder(id)
|
||||
}()
|
||||
}
|
||||
|
||||
return errs.Collect()
|
||||
}
|
||||
|
||||
// wsCancelOrder cancels an open order
|
||||
func (k *Kraken) wsCancelOrder(orderID string) error {
|
||||
id := k.Websocket.AuthConn.GenerateMessageID(false)
|
||||
request := WsCancelOrderRequest{
|
||||
Event: krakenWsCancelOrder,
|
||||
Token: authToken,
|
||||
TransactionIDs: orderIDs,
|
||||
TransactionIDs: []string{orderID},
|
||||
RequestID: id,
|
||||
}
|
||||
|
||||
cancelOrdersStatus[id] = &struct {
|
||||
Total int
|
||||
Successful int
|
||||
Unsuccessful int
|
||||
Error string
|
||||
}{
|
||||
Total: len(orderIDs),
|
||||
}
|
||||
|
||||
defer delete(cancelOrdersStatus, id)
|
||||
|
||||
_, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request)
|
||||
resp, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%w %s: %w", errCancellingOrder, orderID, err)
|
||||
}
|
||||
|
||||
successful := cancelOrdersStatus[id].Successful
|
||||
|
||||
if cancelOrdersStatus[id].Error != "" || len(orderIDs) != successful { // strange Kraken logic ...
|
||||
var reason string
|
||||
if cancelOrdersStatus[id].Error != "" {
|
||||
reason = " Reason: " + cancelOrdersStatus[id].Error
|
||||
}
|
||||
return fmt.Errorf("%s cancelled %d out of %d orders.%s",
|
||||
k.Name, successful, len(orderIDs), reason)
|
||||
status, err := jsonparser.GetUnsafeString(resp, "status")
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w 'status': %w from message: %s", errParsingWSField, err, resp)
|
||||
} else if status == "ok" {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
err = errUnknownError
|
||||
if msg, pErr := jsonparser.GetUnsafeString(resp, "errorMessage"); pErr == nil && msg != "" {
|
||||
err = errors.New(msg)
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w %s: %w", errCancellingOrder, orderID, err)
|
||||
}
|
||||
|
||||
// wsCancelAllOrders cancels all opened orders
|
||||
|
||||
@@ -77,24 +77,28 @@ func MockHTTPInstance(e exchange.IBotExchange) error {
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
type wsMockFunc func(msg []byte, w *websocket.Conn) error
|
||||
// WsMockFunc is a websocket handler to be called with each websocket message
|
||||
type WsMockFunc func([]byte, *websocket.Conn) error
|
||||
|
||||
// MockWSInstance creates a new Exchange instance with a mock WS instance and HTTP server
|
||||
// It accepts an exchange package type argument and a mock WS function
|
||||
// MockWsInstance creates a new Exchange instance with a mock websocket instance and HTTP server
|
||||
// It accepts an exchange package type argument and a http.HandlerFunc
|
||||
// See CurryWsMockUpgrader for a convenient way to curry t and a ws mock function
|
||||
// It is expected to be run from any WS tests which need a specific response function
|
||||
func MockWSInstance[T any, PT interface {
|
||||
// No default subscriptions will be run since they disrupt unit tests
|
||||
func MockWsInstance[T any, PT interface {
|
||||
*T
|
||||
exchange.IBotExchange
|
||||
}](tb testing.TB, m wsMockFunc) *T {
|
||||
}](tb testing.TB, h http.HandlerFunc) *T {
|
||||
tb.Helper()
|
||||
|
||||
e := PT(new(T))
|
||||
require.NoError(tb, TestInstance(e), "TestInstance setup should not error")
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wsMockWrapper(tb, w, r, m) }))
|
||||
s := httptest.NewServer(h)
|
||||
|
||||
b := e.GetBase()
|
||||
b.SkipAuthCheck = true
|
||||
b.API.AuthenticatedWebsocketSupport = true
|
||||
err := b.API.Endpoints.SetRunning("RestSpotURL", s.URL)
|
||||
require.NoError(tb, err, "Endpoints.SetRunning should not error for RestSpotURL")
|
||||
for _, auth := range []bool{true, false} {
|
||||
@@ -102,30 +106,40 @@ func MockWSInstance[T any, PT interface {
|
||||
require.NoErrorf(tb, err, "SetWebsocketURL should not error for auth: %v", auth)
|
||||
}
|
||||
|
||||
// Disable default subscriptions; Would disrupt unit tests
|
||||
b.Features.Subscriptions = []*subscription.Subscription{}
|
||||
// Exchanges which don't support subscription conf; Can be removed when all exchanges support sub conf
|
||||
b.Websocket.GenerateSubs = func() ([]subscription.Subscription, error) { return []subscription.Subscription{}, nil }
|
||||
|
||||
err = b.Websocket.Connect()
|
||||
require.NoError(tb, err, "Connect should not error")
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
// wsMockWrapper handles upgrading an initial HTTP request to WS, and then runs a for loop calling the mock func on each input
|
||||
func wsMockWrapper(tb testing.TB, w http.ResponseWriter, r *http.Request, m wsMockFunc) {
|
||||
// CurryWsMockUpgrader curries a WsMockUpgrader with a testing.TB and a mock func
|
||||
// bridging the gap between information known before the Server is created and during a request
|
||||
func CurryWsMockUpgrader(tb testing.TB, wsHandler WsMockFunc) http.HandlerFunc {
|
||||
tb.Helper()
|
||||
// TODO: This needs to move once this branch includes #1358, probably to use a new mock HTTP instance for kraken
|
||||
if strings.Contains(r.URL.Path, "GetWebSocketsToken") {
|
||||
_, err := w.Write([]byte(`{"result":{"token":"mockAuth"}}`))
|
||||
require.NoError(tb, err, "Write should not error")
|
||||
return
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
WsMockUpgrader(tb, w, r, wsHandler)
|
||||
}
|
||||
}
|
||||
|
||||
// WsMockUpgrader handles upgrading an initial HTTP request to WS, and then runs a for loop calling the mock func on each input
|
||||
func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHandler WsMockFunc) {
|
||||
tb.Helper()
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
require.NoError(tb, err, "Upgrade connection should not error")
|
||||
defer c.Close()
|
||||
for {
|
||||
_, p, err := c.ReadMessage()
|
||||
if websocket.IsUnexpectedCloseError(err) {
|
||||
return
|
||||
}
|
||||
require.NoError(tb, err, "ReadMessage should not error")
|
||||
|
||||
err = m(p, c)
|
||||
err = wsHandler(p, c)
|
||||
assert.NoError(tb, err, "WS Mock Function should not error")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user