diff --git a/main.go b/main.go index f9b1ffcf..04e81e3b 100644 --- a/main.go +++ b/main.go @@ -123,9 +123,6 @@ func main() { bot.portfolio.SeedPortfolio(bot.config.Portfolio) SeedExchangeAccountInfo(GetAllEnabledExchangeAccountInfo().Data) go portfolio.StartPortfolioWatcher() - - log.Println("Starting websocket handler.") - go WebsocketHandler() go TickerUpdaterRoutine() go OrderbookUpdaterRoutine() @@ -135,8 +132,18 @@ func main() { "HTTP Webserver support enabled. Listen URL: http://%s:%d/\n", common.ExtractHost(listenAddr), common.ExtractPort(listenAddr), ) + router := NewRouter(bot.exchanges) - log.Fatal(http.ListenAndServe(listenAddr, router)) + go func() { + err = http.ListenAndServe(listenAddr, router) + if err != nil { + log.Fatal(err) + } + }() + + log.Println("HTTP Webserver started successfully.") + log.Println("Starting websocket handler.") + StartWebsocketHandler() } else { log.Println("HTTP RESTful Webserver support disabled.") } diff --git a/routines.go b/routines.go index df040dba..cd3fc700 100644 --- a/routines.go +++ b/routines.go @@ -206,7 +206,7 @@ func TickerUpdaterRoutine() { result, err = exch.GetTickerPrice(c, assetType) } printTickerSummary(result, c, assetType, exchangeName, err) - if err == nil { + if bot.config.Webserver.Enabled && err == nil { relayWebsocketEvent(result, "ticker_update", assetType, exchangeName) } } @@ -254,7 +254,7 @@ func OrderbookUpdaterRoutine() { processOrderbook := func(exch exchange.IBotExchange, c pair.CurrencyPair, assetType string) { result, err := exch.UpdateOrderbook(c, assetType) printOrderbookSummary(result, c, assetType, exchangeName, err) - if err == nil { + if bot.config.Webserver.Enabled && err == nil { relayWebsocketEvent(result, "orderbook_update", assetType, exchangeName) } } diff --git a/websocket.go b/websocket.go index 51a4b4fe..473d34a9 100644 --- a/websocket.go +++ b/websocket.go @@ -1,10 +1,8 @@ package main import ( - "errors" "log" "net/http" - "time" "github.com/gorilla/websocket" "github.com/thrasher-/gocryptotrader/common" @@ -17,12 +15,39 @@ const ( WebsocketResponseSuccess = "OK" ) +var ( + wsHub *WebsocketHub + wsHubStarted bool +) + +type wsCommandHandler func(client *WebsocketClient, data interface{}) error + +var wsHandlers = map[string]wsCommandHandler{ + "getconfig": wsGetConfig, + "saveconfig": wsSaveConfig, + "getaccountinfo": wsGetAccountInfo, + "gettickers": wsGetTickers, + "getticker": wsGetTicker, + "getorderbooks": wsGetOrderbooks, + "getorderbook": wsGetOrderbook, + "getexchangerates": wsGetExchangeRates, + "getportfolio": wsGetPortfolio, +} + // WebsocketClient stores information related to the websocket client type WebsocketClient struct { - ID int + Hub *WebsocketHub Conn *websocket.Conn - LastRecv time.Time Authenticated bool + Send chan []byte +} + +// WebsocketHub stores the data for managing websocket clients +type WebsocketHub struct { + Clients map[*WebsocketClient]bool + Broadcast chan []byte + Register chan *WebsocketClient + Unregister chan *WebsocketClient } // WebsocketEvent is the struct used for websocket events @@ -48,17 +73,215 @@ type WebsocketOrderbookTickerRequest struct { AssetType string `json:"assetType"` } -// WebsocketClientHub stores an array of websocket clients -var WebsocketClientHub []WebsocketClient +// WebsocketAuth is a struct used for +type WebsocketAuth struct { + Username string `json:"username"` + Password string `json:"password"` +} + +// NewWebsocketHub Creates a new websocket hub +func NewWebsocketHub() *WebsocketHub { + return &WebsocketHub{ + Broadcast: make(chan []byte), + Register: make(chan *WebsocketClient), + Unregister: make(chan *WebsocketClient), + Clients: make(map[*WebsocketClient]bool), + } +} + +func (h *WebsocketHub) run() { + for { + select { + case client := <-h.Register: + h.Clients[client] = true + case client := <-h.Unregister: + if _, ok := h.Clients[client]; ok { + log.Printf("websocket: disconnected client") + delete(h.Clients, client) + close(client.Send) + } + case message := <-h.Broadcast: + for client := range h.Clients { + select { + case client.Send <- message: + default: + log.Printf("websocket: disconnected client") + close(client.Send) + delete(h.Clients, client) + } + } + } + } +} + +// SendWebsocketMessage sends a websocket event to the client +func (c *WebsocketClient) SendWebsocketMessage(evt interface{}) error { + data, err := common.JSONEncode(evt) + if err != nil { + log.Printf("websocket: failed to send message: %s", err) + return err + } + + c.Send <- data + return nil +} + +func (c *WebsocketClient) read() { + defer func() { + c.Hub.Unregister <- c + c.Conn.Close() + }() + + for { + msgType, message, err := c.Conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("websocket: client disconnected, err: %s", err) + } + break + } + + if msgType == websocket.TextMessage { + var evt WebsocketEvent + err := common.JSONDecode(message, &evt) + if err != nil { + log.Printf("websocket: failed to decode JSON sent from client %s", err) + break + } + + if evt.Event == "" { + log.Printf("websocket: client sent a blank event, disconnecting") + break + } + + dataJSON, err := common.JSONEncode(evt.Data) + if err != nil { + log.Printf("websocket: client sent data we couldn't JSON decode") + break + } + + req := common.StringToLower(evt.Event) + log.Printf("websocket: request received: %s", req) + + if !c.Authenticated && evt.Event != "auth" { + wsResp := WebsocketEventResponse{ + Event: "auth", + Error: "you must authenticate first", + } + + c.SendWebsocketMessage(wsResp) + log.Printf("websocket: client didn't auth, disconnecting!") + break + } else if !c.Authenticated && evt.Event == "auth" { + var auth WebsocketAuth + err = common.JSONDecode(dataJSON, &auth) + if err != nil { + log.Println(err) + continue + } + hashPW := common.HexEncodeToString(common.GetSHA256([]byte(bot.config.Webserver.AdminPassword))) + if auth.Username == bot.config.Webserver.AdminUsername && auth.Password == hashPW { + c.Authenticated = true + wsResp := WebsocketEventResponse{ + Event: "auth", + Data: WebsocketResponseSuccess, + } + c.SendWebsocketMessage(wsResp) + log.Println("websocket: client authenticated successfully") + continue + } else { + wsResp := WebsocketEventResponse{ + Event: "auth", + Error: "invalid username/password", + } + c.SendWebsocketMessage(wsResp) + log.Printf("websocket: client sent wrong username/password") + break + } + } + + result, ok := wsHandlers[req] + if !ok { + log.Printf("websocket: unsupported event") + continue + } + + err = result(c, dataJSON) + if err != nil { + log.Printf("websocket: request %s failed. Error %s", evt.Event, err) + continue + } + } + } +} + +func (c *WebsocketClient) write() { + defer func() { + c.Conn.Close() + }() + for { + select { + case message, ok := <-c.Send: + if !ok { + c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + log.Printf("websocket: hub closed the channel") + return + } + + w, err := c.Conn.NextWriter(websocket.TextMessage) + if err != nil { + log.Printf("websocket: failed to create new io.writeCloser: %s", err) + return + } + w.Write(message) + + // Add queued chat messages to the current websocket message + n := len(c.Send) + for i := 0; i < n; i++ { + w.Write(<-c.Send) + } + + if err := w.Close(); err != nil { + log.Printf("websocket: failed to close io.WriteCloser: %s", err) + return + } + } + } +} + +// StartWebsocketHandler starts the websocket hub and routine which +// handles clients +func StartWebsocketHandler() { + if !wsHubStarted { + wsHubStarted = true + wsHub = NewWebsocketHub() + go wsHub.run() + } +} + +// BroadcastWebsocketMessage meow +func BroadcastWebsocketMessage(evt WebsocketEvent) error { + data, err := common.JSONEncode(evt) + if err != nil { + return err + } + + wsHub.Broadcast <- data + return nil +} // WebsocketClientHandler upgrades the HTTP connection to a websocket // compatible one func WebsocketClientHandler(w http.ResponseWriter, r *http.Request) { + if !wsHubStarted { + StartWebsocketHandler() + } + connectionLimit := bot.config.Webserver.WebsocketConnectionLimit - numClients := len(WebsocketClientHub) + numClients := len(wsHub.Clients) if numClients >= connectionLimit { - log.Printf("Websocket client rejected due to websocket client limit reached. Number of clients %d. Limit %d.", + log.Printf("websocket: client rejected due to websocket client limit reached. Number of clients %d. Limit %d.", numClients, connectionLimit) w.WriteHeader(http.StatusForbidden) return @@ -75,83 +298,30 @@ func WebsocketClientHandler(w http.ResponseWriter, r *http.Request) { upgrader.CheckOrigin = func(r *http.Request) bool { return true } } - newClient := WebsocketClient{ - ID: len(WebsocketClientHub), - } - conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println(err) return } - newClient.Conn = conn - WebsocketClientHub = append(WebsocketClientHub, newClient) - numClients++ - log.Printf("New websocket client connected. Connected clients: %d. Limit %d.", - numClients, connectionLimit) + client := &WebsocketClient{Hub: wsHub, Conn: conn, Send: make(chan []byte, 1024)} + client.Hub.Register <- client + log.Printf("websocket: client connected. Connected clients: %d. Limit %d.", + numClients+1, connectionLimit) + + go client.read() + go client.write() } -// DisconnectWebsocketClient disconnects a websocket client -func DisconnectWebsocketClient(id int, err error) { - for i := range WebsocketClientHub { - if WebsocketClientHub[i].ID == id { - WebsocketClientHub[i].Conn.Close() - WebsocketClientHub = append(WebsocketClientHub[:i], WebsocketClientHub[i+1:]...) - log.Printf("Disconnected Websocket client, error: %s", err) - return - } - } -} - -// SendWebsocketMessage sends a websocket message to a specific client -func SendWebsocketMessage(id int, data interface{}) error { - for _, x := range WebsocketClientHub { - if x.ID == id { - return x.Conn.WriteJSON(data) - } - } - return nil -} - -// BroadcastWebsocketMessage broadcasts a websocket event message to all -// websocket clients -func BroadcastWebsocketMessage(evt WebsocketEvent) error { - for _, x := range WebsocketClientHub { - x.Conn.WriteJSON(evt) - } - return nil -} - -// WebsocketAuth is a struct used for -type WebsocketAuth struct { - Username string `json:"username"` - Password string `json:"password"` -} - -type wsCommandHandler func(wsClient *websocket.Conn, data interface{}) error - -var wsHandlers = map[string]wsCommandHandler{ - "getconfig": wsGetConfig, - "saveconfig": wsSaveConfig, - "getaccountinfo": wsGetAccountInfo, - "gettickers": wsGetTickers, - "getticker": wsGetTicker, - "getorderbooks": wsGetOrderbooks, - "getorderbook": wsGetOrderbook, - "getexchangerates": wsGetExchangeRates, - "getportfolio": wsGetPortfolio, -} - -func wsGetConfig(wsClient *websocket.Conn, data interface{}) error { +func wsGetConfig(client *WebsocketClient, data interface{}) error { wsResp := WebsocketEventResponse{ Event: "GetConfig", Data: bot.config, } - return wsClient.WriteJSON(wsResp) + return client.SendWebsocketMessage(wsResp) } -func wsSaveConfig(wsClient *websocket.Conn, data interface{}) error { +func wsSaveConfig(client *WebsocketClient, data interface{}) error { wsResp := WebsocketEventResponse{ Event: "SaveConfig", } @@ -159,44 +329,46 @@ func wsSaveConfig(wsClient *websocket.Conn, data interface{}) error { err := common.JSONDecode(data.([]byte), &cfg) if err != nil { wsResp.Error = err.Error() - err = wsClient.WriteJSON(wsResp) + err = client.SendWebsocketMessage(wsResp) if err != nil { return err } + return err } err = bot.config.UpdateConfig(bot.configFile, cfg) if err != nil { wsResp.Error = err.Error() - err = wsClient.WriteJSON(wsResp) + err = client.SendWebsocketMessage(wsResp) if err != nil { return err } + return err } SetupExchanges() wsResp.Data = WebsocketResponseSuccess - return wsClient.WriteJSON(wsResp) + return client.SendWebsocketMessage(wsResp) } -func wsGetAccountInfo(wsClient *websocket.Conn, data interface{}) error { +func wsGetAccountInfo(client *WebsocketClient, data interface{}) error { accountInfo := GetAllEnabledExchangeAccountInfo() wsResp := WebsocketEventResponse{ Event: "GetAccountInfo", Data: accountInfo, } - return wsClient.WriteJSON(wsResp) + return client.SendWebsocketMessage(wsResp) } -func wsGetTickers(wsClient *websocket.Conn, data interface{}) error { +func wsGetTickers(client *WebsocketClient, data interface{}) error { wsResp := WebsocketEventResponse{ Event: "GetTickers", } wsResp.Data = GetAllActiveTickers() - return wsClient.WriteJSON(wsResp) + return client.SendWebsocketMessage(wsResp) } -func wsGetTicker(wsClient *websocket.Conn, data interface{}) error { +func wsGetTicker(client *WebsocketClient, data interface{}) error { wsResp := WebsocketEventResponse{ Event: "GetTicker", } @@ -204,7 +376,7 @@ func wsGetTicker(wsClient *websocket.Conn, data interface{}) error { err := common.JSONDecode(data.([]byte), &tickerReq) if err != nil { wsResp.Error = err.Error() - wsClient.WriteJSON(wsResp) + client.SendWebsocketMessage(wsResp) return err } @@ -213,22 +385,22 @@ func wsGetTicker(wsClient *websocket.Conn, data interface{}) error { if err != nil { wsResp.Error = err.Error() - wsClient.WriteJSON(wsResp) + client.SendWebsocketMessage(wsResp) return err } wsResp.Data = result - return wsClient.WriteJSON(wsResp) + return client.SendWebsocketMessage(wsResp) } -func wsGetOrderbooks(wsClient *websocket.Conn, data interface{}) error { +func wsGetOrderbooks(client *WebsocketClient, data interface{}) error { wsResp := WebsocketEventResponse{ Event: "GetOrderbooks", } wsResp.Data = GetAllActiveOrderbooks() - return wsClient.WriteJSON(wsResp) + return client.SendWebsocketMessage(wsResp) } -func wsGetOrderbook(wsClient *websocket.Conn, data interface{}) error { +func wsGetOrderbook(client *WebsocketClient, data interface{}) error { wsResp := WebsocketEventResponse{ Event: "GetOrderbook", } @@ -236,7 +408,7 @@ func wsGetOrderbook(wsClient *websocket.Conn, data interface{}) error { err := common.JSONDecode(data.([]byte), &orderbookReq) if err != nil { wsResp.Error = err.Error() - wsClient.WriteJSON(wsResp) + client.SendWebsocketMessage(wsResp) return err } @@ -245,14 +417,14 @@ func wsGetOrderbook(wsClient *websocket.Conn, data interface{}) error { if err != nil { wsResp.Error = err.Error() - wsClient.WriteJSON(wsResp) + client.SendWebsocketMessage(wsResp) return err } wsResp.Data = result - return wsClient.WriteJSON(wsResp) + return client.SendWebsocketMessage(wsResp) } -func wsGetExchangeRates(wsClient *websocket.Conn, data interface{}) error { +func wsGetExchangeRates(client *WebsocketClient, data interface{}) error { wsResp := WebsocketEventResponse{ Event: "GetExchangeRates", } @@ -261,89 +433,13 @@ func wsGetExchangeRates(wsClient *websocket.Conn, data interface{}) error { } else { wsResp.Data = currency.CurrencyStoreFixer } - return wsClient.WriteJSON(wsResp) + return client.SendWebsocketMessage(wsResp) } -func wsGetPortfolio(wsClient *websocket.Conn, data interface{}) error { +func wsGetPortfolio(client *WebsocketClient, data interface{}) error { wsResp := WebsocketEventResponse{ Event: "GetPortfolio", } wsResp.Data = bot.portfolio.GetPortfolioSummary() - return wsClient.WriteJSON(wsResp) -} - -// WebsocketHandler Handles websocket client requests -func WebsocketHandler() { - for { - for x := range WebsocketClientHub { - var evt WebsocketEvent - err := WebsocketClientHub[x].Conn.ReadJSON(&evt) - if err != nil { - DisconnectWebsocketClient(x, err) - continue - } - - if evt.Event == "" { - DisconnectWebsocketClient(x, errors.New("Websocket client sent data we did not understand")) - continue - } - - dataJSON, err := common.JSONEncode(evt.Data) - if err != nil { - log.Println(err) - continue - } - - req := common.StringToLower(evt.Event) - log.Printf("Websocket req: %s", req) - - if !WebsocketClientHub[x].Authenticated && evt.Event != "auth" { - wsResp := WebsocketEventResponse{ - Event: "auth", - Error: "you must authenticate first", - } - SendWebsocketMessage(x, wsResp) - DisconnectWebsocketClient(x, errors.New("Websocket client did not auth")) - continue - } else if !WebsocketClientHub[x].Authenticated && evt.Event == "auth" { - var auth WebsocketAuth - err = common.JSONDecode(dataJSON, &auth) - if err != nil { - log.Println(err) - continue - } - hashPW := common.HexEncodeToString(common.GetSHA256([]byte(bot.config.Webserver.AdminPassword))) - if auth.Username == bot.config.Webserver.AdminUsername && auth.Password == hashPW { - WebsocketClientHub[x].Authenticated = true - wsResp := WebsocketEventResponse{ - Event: "auth", - Data: WebsocketResponseSuccess, - } - SendWebsocketMessage(x, wsResp) - log.Println("Websocket client authenticated successfully") - continue - } else { - wsResp := WebsocketEventResponse{ - Event: "auth", - Error: "invalid username/password", - } - SendWebsocketMessage(x, wsResp) - DisconnectWebsocketClient(x, errors.New("Websocket client sent wrong username/password")) - continue - } - } - result, ok := wsHandlers[req] - if !ok { - log.Printf("Websocket unsupported event") - continue - } - - err = result(WebsocketClientHub[x].Conn, dataJSON) - if err != nil { - log.Printf("Websocket request %s failed. Error %s", evt.Event, err) - continue - } - } - time.Sleep(time.Millisecond) - } + return client.SendWebsocketMessage(wsResp) }