mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-14 07:26:47 +00:00
engine/gRPC proxy: Fix mux regression and add test coverage (#1456)
* engine/gRPC proxy: Fix mux regression and enhance test coverage * Use a temp dir for TLS creds and add credentials test tables * Update GetRPCEndpoints grpcProxyName ListenAddr field * Log unauthorised access attempts
This commit is contained in:
@@ -107,7 +107,7 @@ func (bot *Engine) GetRPCEndpoints() (map[string]RPCEndpoint, error) {
|
||||
},
|
||||
grpcProxyName: {
|
||||
Started: bot.Settings.EnableGRPCProxy,
|
||||
ListenAddr: "http://" + bot.Config.RemoteControl.GRPC.GRPCProxyListenAddress,
|
||||
ListenAddr: "https://" + bot.Config.RemoteControl.GRPC.GRPCProxyListenAddress,
|
||||
},
|
||||
DeprecatedName: {
|
||||
Started: bot.Settings.EnableDeprecatedRPC,
|
||||
|
||||
@@ -172,12 +172,14 @@ func StartRPCServer(engine *Engine) {
|
||||
|
||||
// StartRPCRESTProxy starts a gRPC proxy
|
||||
func (s *RPCServer) StartRPCRESTProxy() {
|
||||
log.Debugf(log.GRPCSys, "gRPC proxy server support enabled. Starting gRPC proxy server on http://%v.\n", s.Config.RemoteControl.GRPC.GRPCProxyListenAddress)
|
||||
log.Debugf(log.GRPCSys, "gRPC proxy server support enabled. Starting gRPC proxy server on https://%v.\n", s.Config.RemoteControl.GRPC.GRPCProxyListenAddress)
|
||||
|
||||
targetDir := utils.GetTLSDir(s.Settings.DataDir)
|
||||
creds, err := credentials.NewClientTLSFromFile(filepath.Join(targetDir, "cert.pem"), "")
|
||||
certFile := filepath.Join(targetDir, "cert.pem")
|
||||
keyFile := filepath.Join(targetDir, "key.pem")
|
||||
creds, err := credentials.NewClientTLSFromFile(certFile, "")
|
||||
if err != nil {
|
||||
log.Errorf(log.GRPCSys, "Unabled to start gRPC proxy. Err: %s\n", err)
|
||||
log.Errorf(log.GRPCSys, "Unable to start gRPC proxy. Err: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -200,16 +202,31 @@ func (s *RPCServer) StartRPCRESTProxy() {
|
||||
Addr: s.Config.RemoteControl.GRPC.GRPCProxyListenAddress,
|
||||
ReadHeaderTimeout: time.Minute,
|
||||
ReadTimeout: time.Minute,
|
||||
Handler: s.authClient(mux),
|
||||
}
|
||||
|
||||
if err = server.ListenAndServe(); err != nil {
|
||||
log.Errorf(log.GRPCSys, "GRPC proxy failed to server: %s\n", err)
|
||||
if err = server.ListenAndServeTLS(certFile, keyFile); err != nil {
|
||||
log.Errorf(log.GRPCSys, "gRPC proxy server failed to serve: %s\n", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debugln(log.GRPCSys, "gRPC proxy server started!")
|
||||
}
|
||||
|
||||
func (s *RPCServer) authClient(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok || username != s.Config.RemoteControl.Username || password != s.Config.RemoteControl.Password {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="restricted"`)
|
||||
http.Error(w, "Access denied", http.StatusUnauthorized)
|
||||
log.Warnf(log.GRPCSys, "gRPC proxy server unauthorised access attempt. IP: %s Path: %s\n", r.RemoteAddr, r.URL.Path)
|
||||
return
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// GetInfo returns info about the current GoCryptoTrader session
|
||||
func (s *RPCServer) GetInfo(_ context.Context, _ *gctrpc.GetInfoRequest) (*gctrpc.GetInfoResponse, error) {
|
||||
rpcEndpoints, err := s.getRPCEndpoints()
|
||||
|
||||
@@ -2,12 +2,20 @@ package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -16,6 +24,7 @@ import (
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/convert"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/key"
|
||||
@@ -4139,3 +4148,145 @@ func TestGetOpenInterest(t *testing.T) {
|
||||
_, err = s.GetOpenInterest(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStartRPCRESTProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := filepath.Join(os.TempDir(), "gct-grpc-proxy-test")
|
||||
tempDirTLS := filepath.Join(tempDir, "tls")
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoErrorf(t, os.RemoveAll(tempDir), "RemoveAll should not error, manual directory deletion required for TempDir: %s", tempDir)
|
||||
})
|
||||
|
||||
if !assert.NoError(t, genCert(tempDirTLS), "genCert should not error") {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
gRPCPort := rand.Intn(65535-42069) + 42069 //nolint:gosec // Don't require crypto/rand usage here
|
||||
gRPCProxyPort := gRPCPort + 1
|
||||
|
||||
e := &Engine{
|
||||
Config: &config.Config{
|
||||
RemoteControl: config.RemoteControlConfig{
|
||||
Username: "bobmarley",
|
||||
Password: "Sup3rdup3rS3cr3t",
|
||||
GRPC: config.GRPCConfig{
|
||||
Enabled: true,
|
||||
ListenAddress: "localhost:" + strconv.Itoa(gRPCPort),
|
||||
GRPCProxyListenAddress: "localhost:" + strconv.Itoa(gRPCProxyPort),
|
||||
},
|
||||
},
|
||||
},
|
||||
Settings: Settings{
|
||||
DataDir: tempDir,
|
||||
CoreSettings: CoreSettings{EnableGRPCProxy: true},
|
||||
},
|
||||
}
|
||||
|
||||
fakeTime := time.Now().Add(-time.Hour)
|
||||
e.uptime = fakeTime
|
||||
|
||||
StartRPCServer(e)
|
||||
|
||||
// Give the proxy time to start
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
|
||||
certFile := filepath.Join(tempDirTLS, "cert.pem")
|
||||
caCert, err := os.ReadFile(certFile)
|
||||
require.NoError(t, err, "ReadFile should not error")
|
||||
caCertPool := x509.NewCertPool()
|
||||
ok := caCertPool.AppendCertsFromPEM(caCert)
|
||||
require.True(t, ok, "AppendCertsFromPEM should return true")
|
||||
client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}}}
|
||||
|
||||
for _, creds := range []struct {
|
||||
testDescription string
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
{"Valid credentials", "bobmarley", "Sup3rdup3rS3cr3t"},
|
||||
{"Valid username but invalid password", "bobmarley", "wrongpass"},
|
||||
{"Invalid username but valid password", "bonk", "Sup3rdup3rS3cr3t"},
|
||||
{"Invalid username and password despite glorious credentials", "bonk", "wif"},
|
||||
} {
|
||||
creds := creds
|
||||
t.Run(creds.testDescription, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://localhost:"+strconv.Itoa(gRPCProxyPort)+"/v1/getinfo", http.NoBody)
|
||||
require.NoError(t, err, "NewRequestWithContext should not error")
|
||||
req.SetBasicAuth(creds.username, creds.password)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err, "Do should not error")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if creds.username == "bobmarley" && creds.password == "Sup3rdup3rS3cr3t" {
|
||||
var info gctrpc.GetInfoResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&info)
|
||||
require.NoError(t, err, "Decode should not error")
|
||||
|
||||
uptimeDuration, err := time.ParseDuration(info.Uptime)
|
||||
require.NoError(t, err, "ParseDuration should not error")
|
||||
assert.InDelta(t, time.Since(fakeTime).Seconds(), uptimeDuration.Seconds(), 1.0, "Uptime should be within 1 second of the expected duration")
|
||||
} else {
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err, "ReadAll should not error")
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "HTTP status code should be 401")
|
||||
assert.Equal(t, "Access denied\n", string(respBody), "Response body should be 'Access denied\n'")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCProxyAuthClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := new(RPCServer)
|
||||
s.Engine = &Engine{
|
||||
Config: &config.Config{
|
||||
RemoteControl: config.RemoteControlConfig{
|
||||
Username: "bobmarley",
|
||||
Password: "Sup3rdup3rS3cr3t",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write([]byte("MEOW"))
|
||||
require.NoError(t, err, "Write should not error")
|
||||
})
|
||||
|
||||
handler := s.authClient(dummyHandler)
|
||||
|
||||
for _, creds := range []struct {
|
||||
testDescription string
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
{"Valid credentials", "bobmarley", "Sup3rdup3rS3cr3t"},
|
||||
{"Valid username but invalid password", "bobmarley", "wrongpass"},
|
||||
{"Invalid username but valid password", "bonk", "Sup3rdup3rS3cr3t"},
|
||||
{"Invalid username and password despite glorious credentials", "bonk", "wif"},
|
||||
} {
|
||||
creds := creds
|
||||
t.Run(creds.testDescription, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", http.NoBody)
|
||||
require.NoError(t, err, "NewRequestWithContext should not error")
|
||||
req.SetBasicAuth(creds.username, creds.password)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if creds.username == "bobmarley" && creds.password == "Sup3rdup3rS3cr3t" {
|
||||
assert.Equal(t, http.StatusOK, rr.Code, "HTTP status code should be 200")
|
||||
assert.Equal(t, "MEOW", rr.Body.String(), "Response body should be 'MEOW'")
|
||||
} else {
|
||||
assert.Equal(t, http.StatusUnauthorized, rr.Code, "HTTP status code should be 401")
|
||||
assert.Equal(t, "Access denied\n", rr.Body.String(), "Response body should be 'Access denied\n'")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user