Websocket reconnection fix (#541)

* Adds potential fix for websocket reconnection failure

* Addr tests, we now don't return an error, this allows us to reuse existing if still in operation.

* update depends && go mod tidy

* adds in channel direction for parameter

* Add full subscriber function, increased test coverage, initiate go routine after calling routine instance check in connection monitor

* fix linter issue

* use protected methods for setting field variables

* removed function, added tests

* lock sub manipulation

* fix linter issue

* Added in transport idleconnection timeout to fix MACOS reconnection issue when all idle connections are consuming resources

* used protected methods to set underlying fields

* set variable via time.Duration param

* Added in lock around field variable in test

* Addr thrasher nits and expanded exchange tests

* Fix test

* Addr glorious nits

* go mod tidy

* Add a larger timeout for traffic monitor if the test runs slow
This commit is contained in:
Ryan O'Hara-Reid
2020-08-26 15:34:05 +10:00
committed by GitHub
parent 77df837b70
commit 870c8cb90e
9 changed files with 840 additions and 214 deletions

View File

@@ -23,6 +23,8 @@ const (
defaultTrafficPeriod = time.Second
)
var errClosedConnection = errors.New("use of closed network connection")
// New initialises the websocket struct
func New() *Websocket {
return &Websocket{
@@ -43,6 +45,10 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
return errors.New("websocket is nil")
}
if s == nil {
return errors.New("websocket setup is nil")
}
if !w.Init {
return fmt.Errorf("%s Websocket already initialised",
s.ExchangeName)
@@ -96,20 +102,11 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
if s.WebsocketTimeout < time.Second {
return fmt.Errorf("traffic timeout cannot be less than %s", time.Second)
}
w.trafficTimeout = s.WebsocketTimeout
if s.Features == nil {
return errors.New("feature set is nil")
}
w.ShutdownC = make(chan struct{})
w.Wg = new(sync.WaitGroup)
w.SetCanUseAuthenticatedEndpoints(s.AuthenticatedWebsocketAPISupport)
err = w.Initialise()
if err != nil {
return err
}
w.Orderbook.Setup(s.OrderbookBufferLimit,
s.BufferEnabled,
@@ -190,33 +187,28 @@ func (w *Websocket) Connect() error {
return fmt.Errorf("%v Websocket already connected",
w.exchangeName)
}
w.setConnectingStatus(true)
w.dataMonitor()
err := w.trafficMonitor()
if err != nil {
return err
}
w.trafficMonitor()
w.setConnectingStatus(true)
// flush any subscriptions from last connection if needed
w.subscriptionMutex.Lock()
w.subscriptions = nil
w.subscriptionMutex.Unlock()
err = w.connector()
err := w.connector()
if err != nil {
w.setConnectingStatus(false)
return fmt.Errorf("%v Error connecting %s",
w.exchangeName, err)
}
w.setConnectedStatus(true)
w.setConnectingStatus(false)
w.setInit(true)
if !w.IsConnectionMonitorRunning() {
go w.connectionMonitor()
w.connectionMonitor()
}
return nil
@@ -296,64 +288,65 @@ func (w *Websocket) connectionMonitor() {
return
}
w.setConnectionMonitorRunning(true)
timer := time.NewTimer(connectionMonitorDelay)
go func() {
timer := time.NewTimer(connectionMonitorDelay)
for {
if w.verbose {
log.Debugf(log.WebsocketMgr,
"%v websocket: running connection monitor cycle\n",
w.exchangeName)
}
if !w.IsEnabled() {
for {
if w.verbose {
log.Debugf(log.WebsocketMgr,
"%v websocket: connectionMonitor - websocket disabled, shutting down\n",
"%v websocket: running connection monitor cycle\n",
w.exchangeName)
}
if w.IsConnected() {
err := w.Shutdown()
if err != nil {
log.Error(log.WebsocketMgr, err)
if !w.IsEnabled() {
if w.verbose {
log.Debugf(log.WebsocketMgr,
"%v websocket: connectionMonitor - websocket disabled, shutting down\n",
w.exchangeName)
}
if w.IsConnected() {
err := w.Shutdown()
if err != nil {
log.Error(log.WebsocketMgr, err)
}
}
if w.verbose {
log.Debugf(log.WebsocketMgr,
"%v websocket: connection monitor exiting\n",
w.exchangeName)
}
timer.Stop()
w.setConnectionMonitorRunning(false)
return
}
if w.verbose {
log.Debugf(log.WebsocketMgr,
"%v websocket: connection monitor exiting\n",
w.exchangeName)
select {
case err := <-w.ReadMessageErrors:
if isDisconnectionError(err) {
w.setInit(false)
log.Warnf(log.WebsocketMgr,
"%v websocket has been disconnected. Reason: %v",
w.exchangeName, err)
w.setConnectedStatus(false)
} else {
// pass off non disconnect errors to datahandler to manage
w.DataHandler <- err
}
case <-timer.C:
if !w.IsConnecting() && !w.IsConnected() {
err := w.Connect()
if err != nil {
log.Error(log.WebsocketMgr, err)
}
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(connectionMonitorDelay)
}
timer.Stop()
w.setConnectionMonitorRunning(false)
return
}
select {
case err := <-w.ReadMessageErrors:
// check if this error is a disconnection error
if isDisconnectionError(err) {
w.setInit(false)
log.Warnf(log.WebsocketMgr,
"%v websocket has been disconnected. Reason: %v",
w.exchangeName, err)
w.setConnectedStatus(false)
} else {
// pass off non disconnect errors to datahandler to manage
w.DataHandler <- err
}
case <-timer.C:
if !w.IsConnecting() && !w.IsConnected() {
err := w.Connect()
if err != nil {
log.Error(log.WebsocketMgr, err)
}
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(connectionMonitorDelay)
}
}
}()
}
// Shutdown attempts to shut down a websocket connection and associated routines
@@ -434,30 +427,28 @@ func (w *Websocket) FlushChannels() error {
return err
}
}
if len(subs) != 0 {
return w.SubscribeToChannels(subs)
}
return nil
} else if len(unsubs) == 0 {
if len(subs) == 0 {
return nil
}
return w.SubscribeToChannels(subs)
}
if len(subs) < 1 {
return nil
}
return w.SubscribeToChannels(subs)
} else if w.features.FullPayloadSubscribe {
// FullPayloadSubscribe means that the endpoint requires all
// subscriptions to be sent via the websocket connection e.g. if you are
// subscribed to ticker and orderbook but require trades as well, you
// would need to send ticker, orderbook and trades channel subscription
// messages.
} else if w.features.FullPayloadSubscribe {
newsubs, err := w.GenerateSubs()
if err != nil {
return err
}
if len(newsubs) != 0 {
// Purge subscription list as there will be conflicts
w.subscriptionMutex.Lock()
w.subscriptions = nil
w.subscriptionMutex.Unlock()
return w.SubscribeToChannels(newsubs)
}
return nil
@@ -473,9 +464,9 @@ func (w *Websocket) FlushChannels() error {
// trafficMonitor uses a timer of WebsocketTrafficLimitTime and once it expires,
// it will reconnect if the TrafficAlert channel has not received any data. The
// trafficTimer will reset on each traffic alert
func (w *Websocket) trafficMonitor() error {
func (w *Websocket) trafficMonitor() {
if w.IsTrafficMonitorRunning() {
return errors.New("traffic monitor already running")
return
}
w.setTrafficMonitorRunning(true)
w.Wg.Add(1)
@@ -513,32 +504,35 @@ func (w *Websocket) trafficMonitor() error {
}
trafficTimer.Stop()
w.Wg.Done()
err := w.Shutdown()
if err != nil {
log.Errorf(log.WebsocketMgr,
"%v websocket: trafficMonitor shutdown err: %s",
w.exchangeName, err)
if !w.IsConnecting() && w.IsConnected() {
err := w.Shutdown()
if err != nil {
log.Errorf(log.WebsocketMgr,
"%v websocket: trafficMonitor shutdown err: %s",
w.exchangeName, err)
}
}
w.setTrafficMonitorRunning(false)
return
}
// Routine pausing mechanism
go func(p chan struct{}) {
time.Sleep(defaultTrafficPeriod)
p <- struct{}{}
}(pause)
select {
case <-w.ShutdownC:
trafficTimer.Stop()
w.setTrafficMonitorRunning(false)
w.Wg.Done()
return
case <-pause:
if w.IsConnected() {
// Routine pausing mechanism
go func(p chan<- struct{}) {
time.Sleep(defaultTrafficPeriod)
p <- struct{}{}
}(pause)
select {
case <-w.ShutdownC:
trafficTimer.Stop()
w.setTrafficMonitorRunning(false)
w.Wg.Done()
return
case <-pause:
}
}
}
}()
return nil
}
func (w *Websocket) setConnectedStatus(b bool) {
@@ -706,18 +700,6 @@ func (w *Websocket) GetWebsocketURL() string {
return w.runningURL
}
// Initialise verifies status and connects
func (w *Websocket) Initialise() error {
if w.IsEnabled() {
if w.IsInit() {
return nil
}
return fmt.Errorf("%v websocket: already initialised", w.exchangeName)
}
w.setEnabled(w.enabled)
return nil
}
// SetProxyAddress sets websocket proxy address
func (w *Websocket) SetProxyAddress(proxyAddr string) error {
if proxyAddr != "" {
@@ -910,14 +892,8 @@ func isDisconnectionError(err error) bool {
if websocket.IsUnexpectedCloseError(err) {
return true
}
switch e := err.(type) {
case *websocket.CloseError:
return true
case *net.OpError:
if e.Err.Error() == "use of closed network connection" {
return false
}
return true
if _, ok := err.(*net.OpError); ok {
return !errors.Is(err, errClosedConnection)
}
return false
}

View File

@@ -72,6 +72,91 @@ var defaultSetup = &WebsocketSetup{
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
}
type dodgyConnection struct {
WebsocketConnection
}
// override websocket connection method to produce a wicked terrible error
func (d *dodgyConnection) Shutdown() error {
return errors.New("cannot shutdown due to some dastardly reason")
}
// override websocket connection method to produce a wicked terrible error
func (d *dodgyConnection) Connect() error {
return errors.New("cannot connect due to some dastardly reason")
}
func TestSetup(t *testing.T) {
var w *Websocket
err := w.Setup(nil)
if err == nil {
t.Fatal("error cannot be nil")
}
w = &Websocket{}
err = w.Setup(nil)
if err == nil {
t.Fatal("error cannot be nil")
}
w.Init = true
websocketSetup := &WebsocketSetup{}
err = w.Setup(websocketSetup)
if err == nil {
t.Fatal("error cannot be nil")
}
websocketSetup.Features = &protocol.Features{}
err = w.Setup(websocketSetup)
if err == nil {
t.Fatal("error cannot be nil")
}
websocketSetup.Features.Subscribe = true
err = w.Setup(websocketSetup)
if err == nil {
t.Fatal("error cannot be nil")
}
websocketSetup.Subscriber = func([]ChannelSubscription) error { return nil }
websocketSetup.Features.Unsubscribe = true
err = w.Setup(websocketSetup)
if err == nil {
t.Fatal("error cannot be nil")
}
websocketSetup.UnSubscriber = func([]ChannelSubscription) error { return nil }
err = w.Setup(websocketSetup)
if err == nil {
t.Fatal("error cannot be nil")
}
websocketSetup.DefaultURL = "test"
err = w.Setup(websocketSetup)
if err == nil {
t.Fatal("error cannot be nil")
}
websocketSetup.RunningURL = "http://www.google.com"
err = w.Setup(websocketSetup)
if err == nil {
t.Fatal("error cannot be nil")
}
websocketSetup.RunningURL = "wss://www.google.com"
websocketSetup.RunningURLAuth = "http://www.google.com"
err = w.Setup(websocketSetup)
if err == nil {
t.Fatal("error cannot be nil")
}
websocketSetup.RunningURLAuth = "wss://www.google.com"
err = w.Setup(websocketSetup)
if err == nil {
t.Fatal("error cannot be nil")
}
websocketSetup.ExchangeName = "testname"
err = w.Setup(websocketSetup)
if err == nil {
t.Fatal("error cannot be nil")
}
websocketSetup.WebsocketTimeout = time.Minute
err = w.Setup(websocketSetup)
if err != nil {
t.Fatal(err)
}
}
func TestTrafficMonitorTimeout(t *testing.T) {
ws := *New()
err := ws.Setup(defaultSetup)
@@ -80,15 +165,16 @@ func TestTrafficMonitorTimeout(t *testing.T) {
}
ws.trafficTimeout = time.Second
ws.ShutdownC = make(chan struct{})
err = ws.trafficMonitor()
if err != nil {
t.Fatal(err)
ws.trafficMonitor()
if !ws.IsTrafficMonitorRunning() {
t.Fatal("traffic monitor should be running")
}
// try to add another traffic monitor
err = ws.trafficMonitor()
if err == nil {
t.Fatal("expected not allowed")
ws.trafficMonitor()
if !ws.IsTrafficMonitorRunning() {
t.Fatal("traffic monitor should be running")
}
// Deploy traffic alert
ws.TrafficAlert <- struct{}{}
time.Sleep(time.Second * 2)
@@ -112,11 +198,14 @@ func TestIsDisconnectionError(t *testing.T) {
}
isADisconnectionError = isDisconnectionError(&net.OpError{
Op: "",
Net: "",
Source: nil,
Addr: nil,
Err: errors.New("errorText"),
Err: errClosedConnection,
})
if isADisconnectionError {
t.Error("It's not")
}
isADisconnectionError = isDisconnectionError(&net.OpError{
Err: errors.New("errText"),
})
if !isADisconnectionError {
t.Error("It is")
@@ -124,8 +213,35 @@ func TestIsDisconnectionError(t *testing.T) {
}
func TestConnectionMessageErrors(t *testing.T) {
var wsWrong = &Websocket{}
err := wsWrong.Connect()
if err == nil {
t.Fatal("error cannot be nil")
}
wsWrong.connector = func() error { return nil }
err = wsWrong.Connect()
if err == nil {
t.Fatal("error cannot be nil")
}
wsWrong.setEnabled(true)
wsWrong.setConnectingStatus(true)
wsWrong.Wg = &sync.WaitGroup{}
err = wsWrong.Connect()
if err == nil {
t.Fatal("error cannot be nil")
}
wsWrong.setConnectedStatus(false)
wsWrong.connector = func() error { return errors.New("edge case error of dooooooom") }
err = wsWrong.Connect()
if err == nil {
t.Fatal("error cannot be nil")
}
ws := *New()
err := ws.Setup(defaultSetup)
err = ws.Setup(defaultSetup)
if err != nil {
t.Fatal(err)
}
@@ -165,8 +281,8 @@ outer:
}
func TestWebsocket(t *testing.T) {
ws := Websocket{}
err := ws.Setup(&WebsocketSetup{
wsInit := Websocket{}
err := wsInit.Setup(&WebsocketSetup{
ExchangeName: "test",
Enabled: true,
})
@@ -174,15 +290,33 @@ func TestWebsocket(t *testing.T) {
t.Errorf("Expected 'test Websocket already initialised', received %v", err)
}
ws = *New()
ws := *New()
err = ws.SetProxyAddress("garbagio")
if err == nil {
t.Error("error cannot be nil")
}
ws.Conn = &WebsocketConnection{}
ws.AuthConn = &WebsocketConnection{}
ws.setEnabled(true)
err = ws.SetProxyAddress("https://192.168.0.1:1337")
if err != nil {
if err == nil {
t.Error("error cannot be nil")
}
ws.setConnectedStatus(true)
ws.ShutdownC = make(chan struct{})
ws.Wg = &sync.WaitGroup{}
err = ws.SetProxyAddress("https://192.168.0.1:1336")
if err == nil {
t.Error("SetProxyAddress", err)
}
err = ws.SetProxyAddress("https://192.168.0.1:1336")
if err == nil {
t.Error("SetProxyAddress", err)
}
ws.setEnabled(false)
// removing proxy
err = ws.SetProxyAddress("")
if err != nil {
@@ -234,23 +368,65 @@ func TestWebsocket(t *testing.T) {
if err == nil {
t.Fatal("should not be connected to able to shut down")
}
ws.verbose = true
ws.setConnectedStatus(true)
ws.Conn = &dodgyConnection{}
err = ws.Shutdown()
if err == nil {
t.Fatal("error cannot be nil")
}
ws.Conn = &WebsocketConnection{}
ws.setConnectedStatus(true)
ws.AuthConn = &dodgyConnection{}
err = ws.Shutdown()
if err == nil {
t.Fatal("error cannot be nil ")
}
ws.AuthConn = &WebsocketConnection{}
ws.setConnectedStatus(false)
// -- Normal connect
err = ws.Connect()
if err != nil {
t.Fatal("WebsocketSetup", err)
}
ws.defaultURL = "ws://demos.kaazing.com/echo"
ws.defaultURLAuth = "ws://demos.kaazing.com/echo"
err = ws.SetWebsocketURL("", false, false)
if err != nil {
t.Fatal(err)
}
err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", false, false)
if err != nil {
t.Fatal(err)
}
err = ws.SetWebsocketURL("", true, false)
if err != nil {
t.Fatal(err)
}
err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, false)
if err != nil {
t.Fatal(err)
}
// -- Already connected connect
// Attempt reconnect
err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, true)
if err != nil {
t.Fatal(err)
}
// -- initiate the reconnect which is usually handled by connection monitor
err = ws.Connect()
if err != nil {
t.Fatal(err)
}
err = ws.Connect()
if err == nil {
t.Fatal("should not connect, already connected")
t.Fatal("should already be connected")
}
// -- Normal shutdown
err = ws.Shutdown()
@@ -307,6 +483,12 @@ func TestSubscribeUnsubscribe(t *testing.T) {
t.Fatal("error cannot be nil")
}
// subscribe to nothing
err = ws.SubscribeToChannels(nil)
if err == nil {
t.Fatal("error cannot be nil")
}
err = ws.UnsubscribeChannels(subs)
if err != nil {
t.Fatal(err)
@@ -355,7 +537,26 @@ func TestConnectionMonitorNoConnection(t *testing.T) {
ws.ShutdownC = make(chan struct{}, 1)
ws.exchangeName = "hello"
ws.trafficTimeout = 1
go ws.connectionMonitor()
ws.verbose = true
ws.Wg = &sync.WaitGroup{}
ws.connectionMonitor()
if !ws.IsConnectionMonitorRunning() {
t.Fatal("Should not have exited")
}
ws.connectionMonitor() // This one should exit
if !ws.IsConnectionMonitorRunning() {
t.Fatal("Should not have exited")
}
time.Sleep(time.Second)
if ws.IsConnectionMonitorRunning() {
t.Fatal("Should have exited")
}
ws.setConnectedStatus(true) // attempt shutdown when not enabled
ws.setConnectingStatus(true) // throw a spanner in the works
ws.connectionMonitor()
if !ws.IsConnectionMonitorRunning() {
t.Fatal("Should not have exited")
}
time.Sleep(time.Second)
if ws.IsConnectionMonitorRunning() {
t.Fatal("Should have exited")
@@ -824,7 +1025,21 @@ func TestFlushChannels(t *testing.T) {
currency.NewPair(currency.BTC, currency.AUD),
currency.NewPair(currency.BTC, currency.USDT),
}}
web := Websocket{enabled: true,
dodgyWs := Websocket{}
err := dodgyWs.FlushChannels()
if err == nil {
t.Fatal("error cannot be nil")
}
dodgyWs.setEnabled(true)
err = dodgyWs.FlushChannels()
if err == nil {
t.Fatal("error cannot be nil")
}
web := Websocket{
enabled: true,
connected: true,
connector: connect,
ShutdownC: make(chan struct{}),
@@ -833,13 +1048,20 @@ func TestFlushChannels(t *testing.T) {
Wg: new(sync.WaitGroup),
features: &protocol.Features{
// No features
}}
web.GenerateSubs = newgen.generateSubs
subs, err := web.GenerateSubs()
if err != nil {
t.Fatal(err)
},
trafficTimeout: time.Second * 30, // Added for when we utilise connect()
// in FlushChannels() so the traffic monitor doesn't time out and turn
// this to an unconnected state
}
web.subscriptions = subs
problemFunc := func() ([]ChannelSubscription, error) {
return nil, errors.New("problems")
}
noSub := func() ([]ChannelSubscription, error) {
return nil, nil
}
// Disable pair and flush system
newgen.EnabledPairs = []currency.Pair{
currency.NewPair(currency.BTC, currency.AUD)}
@@ -849,16 +1071,67 @@ func TestFlushChannels(t *testing.T) {
}
web.features.FullPayloadSubscribe = true
web.GenerateSubs = problemFunc
err = web.FlushChannels() // error on full subscribeToChannels
if err == nil {
t.Fatal("error cannot be nil")
}
web.GenerateSubs = noSub
err = web.FlushChannels() // No subs to sub
if err != nil {
t.Fatal(err)
}
web.GenerateSubs = newgen.generateSubs
subs, err := web.GenerateSubs()
if err != nil {
t.Fatal(err)
}
web.subscriptionMutex.Lock()
web.subscriptions = subs
web.subscriptionMutex.Unlock()
err = web.FlushChannels()
if err != nil {
t.Fatal(err)
}
web.features.FullPayloadSubscribe = false
web.features.Subscribe = true
web.GenerateSubs = problemFunc
err = web.FlushChannels()
if err == nil {
t.Fatal("error cannot be nil")
}
web.GenerateSubs = newgen.generateSubs
err = web.FlushChannels()
if err != nil {
t.Fatal(err)
}
web.subscriptionMutex.Lock()
web.subscriptions = []ChannelSubscription{
{
Channel: "match channel",
Currency: currency.NewPair(currency.BTC, currency.AUD),
},
{
Channel: "unsub channel",
Currency: currency.NewPair(currency.THETA, currency.USDT),
},
}
web.subscriptionMutex.Unlock()
err = web.FlushChannels()
if err != nil {
t.Fatal(err)
}
err = web.FlushChannels()
if err != nil {
t.Fatal(err)
}
web.setConnectedStatus(true)
web.features.Unsubscribe = true
err = web.FlushChannels()
@@ -903,6 +1176,30 @@ func TestEnable(t *testing.T) {
}
func TestSetupNewConnection(t *testing.T) {
var nonsenseWebsock *Websocket
err := nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"})
if err == nil {
t.Fatal("error cannot be nil")
}
nonsenseWebsock = &Websocket{}
err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"})
if err == nil {
t.Fatal("error cannot be nil")
}
nonsenseWebsock = &Websocket{exchangeName: "test"}
err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"})
if err == nil {
t.Fatal("error cannot be nil")
}
nonsenseWebsock.TrafficAlert = make(chan struct{})
err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"})
if err == nil {
t.Fatal("error cannot be nil")
}
web := Websocket{
connector: connect,
Wg: new(sync.WaitGroup),
@@ -912,7 +1209,7 @@ func TestSetupNewConnection(t *testing.T) {
ReadMessageErrors: make(chan error),
}
err := web.Setup(defaultSetup)
err = web.Setup(defaultSetup)
if err != nil {
t.Fatal(err)
}