diff --git a/connchecker/connchecker.go b/connchecker/connchecker.go index 03b725f7..2c8e338c 100644 --- a/connchecker/connchecker.go +++ b/connchecker/connchecker.go @@ -2,6 +2,7 @@ package connchecker import ( "net" + "strings" "sync" "time" @@ -10,7 +11,14 @@ import ( // DefaultCheckInterval is a const that defines the amount of time between // checking if the connection is lost -const DefaultCheckInterval = time.Second +const ( + DefaultCheckInterval = time.Second + + ConnRe = "Internet connectivity re-established" + ConnLost = "Internet connectivity lost" + ConnFound = "Internet connectivity found" + ConnNotFound = "No internet connectivity" +) // Default check lists var ( @@ -19,8 +27,8 @@ var ( ) // New returns a new connection checker, if no values set it will default it out -func New(dnsList, domainList []string, checkInterval time.Duration) *Checker { - c := &Checker{} +func New(dnsList, domainList []string, checkInterval time.Duration) (*Checker, error) { + c := new(Checker) if len(dnsList) == 0 { c.DNSList = DefaultDNSList } else { @@ -39,8 +47,23 @@ func New(dnsList, domainList []string, checkInterval time.Duration) *Checker { c.CheckInterval = checkInterval } - go c.Monitor() - return c + err := c.initialCheck() + if err != nil { + return nil, err + } + + if c.connected { + log.Debug(ConnFound) + } else { + log.Warnf(ConnNotFound) + } + + c.shutdown = make(chan struct{}, 1) + var wg sync.WaitGroup + wg.Add(1) + go c.Monitor(&wg) + wg.Wait() + return c, nil } // Checker defines a struct to determine connectivity to the interwebs @@ -56,16 +79,16 @@ type Checker struct { // Shutdown cleanly shutsdown monitor routine func (c *Checker) Shutdown() { - c.shutdown <- struct{}{} + close(c.shutdown) c.wg.Wait() } // Monitor determines internet connectivity via a DNS lookup -func (c *Checker) Monitor() { +func (c *Checker) Monitor(wg *sync.WaitGroup) { c.wg.Add(1) tick := time.NewTicker(time.Second) defer func() { tick.Stop(); c.wg.Done() }() - c.connectionTest() + wg.Done() for { select { case <-tick.C: @@ -76,15 +99,45 @@ func (c *Checker) Monitor() { } } +// initialCheck starts an initial connection check +func (c *Checker) initialCheck() error { + var connected bool + for i := range c.DNSList { + err := c.CheckDNS(c.DNSList[i]) + if err != nil { + if strings.Contains(err.Error(), "unrecognized address") || + strings.Contains(err.Error(), "invalid address") { + return err + } + continue + } + if !connected { + connected = true + } + } + + for i := range c.DomainList { + err := c.CheckHost(c.DomainList[i]) + if err != nil { + continue + } + if !connected { + connected = true + } + } + c.connected = connected + return nil +} + // ConnectionTest determines if a connection to the internet is available by // iterating over a set list of dns ip and popular domains func (c *Checker) connectionTest() { for i := range c.DNSList { - _, err := net.LookupAddr(c.DNSList[i]) + err := c.CheckDNS(c.DNSList[i]) if err == nil { c.Lock() if !c.connected { - log.Warnf("Internet connectivity re-established") + log.Debug(ConnRe) c.connected = true } c.Unlock() @@ -93,11 +146,11 @@ func (c *Checker) connectionTest() { } for i := range c.DomainList { - _, err := net.LookupHost(c.DomainList[i]) + err := c.CheckHost(c.DomainList[i]) if err == nil { c.Lock() if !c.connected { - log.Warnf("Internet connectivity re-established") + log.Debug(ConnRe) c.connected = true } c.Unlock() @@ -107,12 +160,24 @@ func (c *Checker) connectionTest() { c.Lock() if c.connected { - log.Warnf("Internet connectivity lost") + log.Warn(ConnLost) c.connected = false } c.Unlock() } +// CheckDNS checks current dns for connectivity +func (c *Checker) CheckDNS(dns string) error { + _, err := net.LookupAddr(dns) + return err +} + +// CheckHost checks current host name for connectivity +func (c *Checker) CheckHost(host string) error { + _, err := net.LookupHost(host) + return err +} + // IsConnected returns if there is internet connectivity func (c *Checker) IsConnected() bool { c.Lock() diff --git a/connchecker/connchecker_test.go b/connchecker/connchecker_test.go new file mode 100644 index 00000000..7b1b27b6 --- /dev/null +++ b/connchecker/connchecker_test.go @@ -0,0 +1,37 @@ +package connchecker + +import ( + "testing" +) + +func TestConnection(t *testing.T) { + faultyDomain := []string{"faultyIP"} + faultyHost := []string{"faultyHost"} + _, err := New(faultyDomain, nil, 100000) + if err == nil { + t.Fatal("Test Failed - New error cannot be nil") + } + + _, err = New(DefaultDNSList, nil, 100000) + if err != nil { + t.Fatal("Test Failed - New error", err) + } + + _, err = New(nil, faultyHost, 100000) + if err != nil { + t.Fatal("Test Failed - New error cannot be nil", err) + } + + c, err := New(nil, nil, 0) + if err != nil { + t.Fatal("Test Failed - New error", err) + } + + if !c.IsConnected() { + t.Log("Test - No internet connection found") + } else { + t.Log("Test - Internet connection found") + } + + c.Shutdown() +} diff --git a/exchanges/okgroup/okgroup_websocket.go b/exchanges/okgroup/okgroup_websocket.go index 319c7da3..d86c5ab6 100644 --- a/exchanges/okgroup/okgroup_websocket.go +++ b/exchanges/okgroup/okgroup_websocket.go @@ -190,14 +190,19 @@ func (o *OKGroup) WsConnect() error { log.Debugf("Successful connection to %v", o.Websocket.GetWebsocketURL()) } - go o.WsHandleData() - go o.wsPingHandler() + var wg sync.WaitGroup + wg.Add(2) + go o.WsHandleData(&wg) + go o.wsPingHandler(&wg) err = o.WsSubscribeToDefaults() if err != nil { return fmt.Errorf("error: Could not subscribe to the OKEX websocket %s", err) } + + // Ensures that we start the routines and we dont race when shutdown occurs + wg.Wait() return nil } @@ -246,12 +251,14 @@ func (o *OKGroup) WsReadData() (exchange.WebsocketResponse, error) { } // wsPingHandler sends a message "ping" every 27 to maintain the connection to the websocket -func (o *OKGroup) wsPingHandler() { +func (o *OKGroup) wsPingHandler(wg *sync.WaitGroup) { o.Websocket.Wg.Add(1) defer o.Websocket.Wg.Done() ticker := time.NewTicker(time.Second * 27) + wg.Done() + for { select { case <-o.Websocket.ShutdownC: @@ -271,7 +278,7 @@ func (o *OKGroup) wsPingHandler() { } // WsHandleData handles the read data from the websocket connection -func (o *OKGroup) WsHandleData() { +func (o *OKGroup) WsHandleData(wg *sync.WaitGroup) { o.Websocket.Wg.Add(1) defer func() { err := o.WebsocketConn.Close() @@ -282,6 +289,8 @@ func (o *OKGroup) WsHandleData() { o.Websocket.Wg.Done() }() + wg.Done() + for { select { case <-o.Websocket.ShutdownC: diff --git a/main.go b/main.go index e9a7db6c..8b43d399 100644 --- a/main.go +++ b/main.go @@ -134,9 +134,12 @@ func main() { } // Sets up internet connectivity monitor - bot.connectivity = connchecker.New(bot.config.ConnectionMonitor.DNSList, + bot.connectivity, err = connchecker.New(bot.config.ConnectionMonitor.DNSList, bot.config.ConnectionMonitor.PublicDomainList, bot.config.ConnectionMonitor.CheckInterval) + if err != nil { + log.Fatalf("Connectivity checker failure: %s", err) + } AdjustGoMaxProcs() log.Debugf("Bot '%s' started.\n", bot.config.Name)