package wshandler

import (
	"bytes"
	"compress/flate"
	"compress/gzip"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/thrasher-/gocryptotrader/common"
	log "github.com/thrasher-/gocryptotrader/logger"
)

// AddResponseWithID stores data in w.IDResponses under id. It takes the
// connection lock for the write and lazily creates the map on first use.
func (w *WebsocketConnection) AddResponseWithID(id int64, data []byte) {
	w.Lock()
	defer w.Unlock()
	if w.IDResponses == nil {
		w.IDResponses = make(map[int64][]byte)
	}
	w.IDResponses[id] = data
}

// Dial connects to w.URL using dialer, sending headers with the handshake,
// and stores the resulting connection in w.Connection (which is set to the
// value returned by dialer.Dial even when dialling fails).
//
// If w.ProxyURL is non-empty it is parsed and installed as dialer.Proxy;
// note that this mutates the caller-supplied dialer.
func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error {
	if w.ProxyURL != "" {
		proxy, err := url.Parse(w.ProxyURL)
		if err != nil {
			return err
		}
		dialer.Proxy = http.ProxyURL(proxy)
	}

	var err error
	var resp *http.Response
	w.Connection, resp, err = dialer.Dial(w.URL, headers)
	if err != nil {
		if resp != nil {
			// Report only the status line ("400 Bad Request"); formatting the
			// whole *http.Response would dump headers and internal pointers
			// into the error text.
			return fmt.Errorf("%v %v Error: %v", w.URL, resp.Status, err)
		}
		return fmt.Errorf("%v Error: %v", w.URL, err)
	}
	return nil
}

// SendMessage is the one true message request.
Sends message to WS func (w *WebsocketConnection) SendMessage(data interface{}) error { w.Lock() defer w.Unlock() json, err := common.JSONEncode(data) if err != nil { return err } if w.Verbose { log.Debugf("%v sending message to websocket %v", w.ExchangeName, string(json)) } if w.RateLimit > 0 { time.Sleep(time.Duration(w.RateLimit) * time.Millisecond) } return w.Connection.WriteMessage(websocket.TextMessage, json) } // SendMessageReturnResponse will send a WS message to the connection // It will then run a goroutine to await a JSON response // If there is no response it will return an error func (w *WebsocketConnection) SendMessageReturnResponse(id int64, request interface{}) ([]byte, error) { err := w.SendMessage(request) if err != nil { return nil, err } var wg sync.WaitGroup wg.Add(1) go w.WaitForResult(id, &wg) defer func() { delete(w.IDResponses, id) }() wg.Wait() if _, ok := w.IDResponses[id]; !ok { return nil, fmt.Errorf("timeout waiting for response with ID %v", id) } return w.IDResponses[id], nil } // WaitForResult will keep checking w.IDResponses for a response ID // If the timer expires, it will return without func (w *WebsocketConnection) WaitForResult(id int64, wg *sync.WaitGroup) { defer wg.Done() timer := time.NewTimer(w.ResponseMaxLimit) for { select { case <-timer.C: return default: w.Lock() for k := range w.IDResponses { if k == id { w.Unlock() return } } w.Unlock() time.Sleep(w.ResponseCheckTimeout) } } } // ReadMessage reads messages, can handle text, gzip and binary func (w *WebsocketConnection) ReadMessage() (WebsocketResponse, error) { mType, resp, err := w.Connection.ReadMessage() if err != nil { return WebsocketResponse{}, err } var standardMessage []byte switch mType { case websocket.TextMessage: standardMessage = resp case websocket.BinaryMessage: standardMessage, err = w.parseBinaryResponse(resp) if err != nil { return WebsocketResponse{}, err } } if w.Verbose { log.Debugf("%v Websocket message received: %v", w.ExchangeName, 
string(standardMessage)) } return WebsocketResponse{Raw: standardMessage, Type: mType}, nil } // parseBinaryResponse parses a websocket binaray response into a usable byte array func (w *WebsocketConnection) parseBinaryResponse(resp []byte) ([]byte, error) { var standardMessage []byte var err error // Detect GZIP if resp[0] == 31 && resp[1] == 139 { b := bytes.NewReader(resp) var gReader *gzip.Reader gReader, err = gzip.NewReader(b) if err != nil { return standardMessage, err } standardMessage, err = ioutil.ReadAll(gReader) if err != nil { return standardMessage, err } err = gReader.Close() if err != nil { return standardMessage, err } } else { reader := flate.NewReader(bytes.NewReader(resp)) standardMessage, err = ioutil.ReadAll(reader) if err != nil { return standardMessage, err } err = reader.Close() if err != nil { return standardMessage, err } } return standardMessage, nil } // GenerateMessageID Creates a messageID to checkout func (w *WebsocketConnection) GenerateMessageID(useNano bool) int64 { if useNano { return time.Now().UnixNano() } return time.Now().Unix() }