engine/gRPC proxy: Fix mux regression and add test coverage (#1456)

* engine/gRPC proxy: Fix mux regression and enhance test coverage

* Use a temp dir for TLS creds and add credentials test tables

* Update GetRPCEndpoints grpcProxyName ListenAddr field

* Log unauthorised access attempts
This commit is contained in:
Adrian Gallagher
2024-02-05 15:22:54 +11:00
committed by GitHub
parent e0c6e118ed
commit d57fefbcfc
3 changed files with 174 additions and 6 deletions

View File

@@ -107,7 +107,7 @@ func (bot *Engine) GetRPCEndpoints() (map[string]RPCEndpoint, error) {
},
grpcProxyName: {
Started: bot.Settings.EnableGRPCProxy,
ListenAddr: "http://" + bot.Config.RemoteControl.GRPC.GRPCProxyListenAddress,
ListenAddr: "https://" + bot.Config.RemoteControl.GRPC.GRPCProxyListenAddress,
},
DeprecatedName: {
Started: bot.Settings.EnableDeprecatedRPC,

View File

@@ -172,12 +172,14 @@ func StartRPCServer(engine *Engine) {
// StartRPCRESTProxy starts a gRPC proxy
func (s *RPCServer) StartRPCRESTProxy() {
log.Debugf(log.GRPCSys, "gRPC proxy server support enabled. Starting gRPC proxy server on http://%v.\n", s.Config.RemoteControl.GRPC.GRPCProxyListenAddress)
log.Debugf(log.GRPCSys, "gRPC proxy server support enabled. Starting gRPC proxy server on https://%v.\n", s.Config.RemoteControl.GRPC.GRPCProxyListenAddress)
targetDir := utils.GetTLSDir(s.Settings.DataDir)
creds, err := credentials.NewClientTLSFromFile(filepath.Join(targetDir, "cert.pem"), "")
certFile := filepath.Join(targetDir, "cert.pem")
keyFile := filepath.Join(targetDir, "key.pem")
creds, err := credentials.NewClientTLSFromFile(certFile, "")
if err != nil {
log.Errorf(log.GRPCSys, "Unabled to start gRPC proxy. Err: %s\n", err)
log.Errorf(log.GRPCSys, "Unable to start gRPC proxy. Err: %s\n", err)
return
}
@@ -200,16 +202,31 @@ func (s *RPCServer) StartRPCRESTProxy() {
Addr: s.Config.RemoteControl.GRPC.GRPCProxyListenAddress,
ReadHeaderTimeout: time.Minute,
ReadTimeout: time.Minute,
Handler: s.authClient(mux),
}
if err = server.ListenAndServe(); err != nil {
log.Errorf(log.GRPCSys, "GRPC proxy failed to server: %s\n", err)
if err = server.ListenAndServeTLS(certFile, keyFile); err != nil {
log.Errorf(log.GRPCSys, "gRPC proxy server failed to serve: %s\n", err)
return
}
}()
log.Debugln(log.GRPCSys, "gRPC proxy server started!")
}
func (s *RPCServer) authClient(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok || username != s.Config.RemoteControl.Username || password != s.Config.RemoteControl.Password {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted"`)
http.Error(w, "Access denied", http.StatusUnauthorized)
log.Warnf(log.GRPCSys, "gRPC proxy server unauthorised access attempt. IP: %s Path: %s\n", r.RemoteAddr, r.URL.Path)
return
}
handler.ServeHTTP(w, r)
})
}
// GetInfo returns info about the current GoCryptoTrader session
func (s *RPCServer) GetInfo(_ context.Context, _ *gctrpc.GetInfoRequest) (*gctrpc.GetInfoResponse, error) {
rpcEndpoints, err := s.getRPCEndpoints()

View File

@@ -2,12 +2,20 @@ package engine
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"sync"
"testing"
@@ -16,6 +24,7 @@ import (
"github.com/gofrs/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/convert"
"github.com/thrasher-corp/gocryptotrader/common/key"
@@ -4139,3 +4148,145 @@ func TestGetOpenInterest(t *testing.T) {
_, err = s.GetOpenInterest(context.Background(), req)
assert.NoError(t, err)
}
func TestStartRPCRESTProxy(t *testing.T) {
t.Parallel()
tempDir := filepath.Join(os.TempDir(), "gct-grpc-proxy-test")
tempDirTLS := filepath.Join(tempDir, "tls")
t.Cleanup(func() {
assert.NoErrorf(t, os.RemoveAll(tempDir), "RemoveAll should not error, manual directory deletion required for TempDir: %s", tempDir)
})
if !assert.NoError(t, genCert(tempDirTLS), "genCert should not error") {
t.FailNow()
}
gRPCPort := rand.Intn(65535-42069) + 42069 //nolint:gosec // Don't require crypto/rand usage here
gRPCProxyPort := gRPCPort + 1
e := &Engine{
Config: &config.Config{
RemoteControl: config.RemoteControlConfig{
Username: "bobmarley",
Password: "Sup3rdup3rS3cr3t",
GRPC: config.GRPCConfig{
Enabled: true,
ListenAddress: "localhost:" + strconv.Itoa(gRPCPort),
GRPCProxyListenAddress: "localhost:" + strconv.Itoa(gRPCProxyPort),
},
},
},
Settings: Settings{
DataDir: tempDir,
CoreSettings: CoreSettings{EnableGRPCProxy: true},
},
}
fakeTime := time.Now().Add(-time.Hour)
e.uptime = fakeTime
StartRPCServer(e)
// Give the proxy time to start
time.Sleep(time.Millisecond * 500)
certFile := filepath.Join(tempDirTLS, "cert.pem")
caCert, err := os.ReadFile(certFile)
require.NoError(t, err, "ReadFile should not error")
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(caCert)
require.True(t, ok, "AppendCertsFromPEM should return true")
client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}}}
for _, creds := range []struct {
testDescription string
username string
password string
}{
{"Valid credentials", "bobmarley", "Sup3rdup3rS3cr3t"},
{"Valid username but invalid password", "bobmarley", "wrongpass"},
{"Invalid username but valid password", "bonk", "Sup3rdup3rS3cr3t"},
{"Invalid username and password despite glorious credentials", "bonk", "wif"},
} {
creds := creds
t.Run(creds.testDescription, func(t *testing.T) {
t.Parallel()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://localhost:"+strconv.Itoa(gRPCProxyPort)+"/v1/getinfo", http.NoBody)
require.NoError(t, err, "NewRequestWithContext should not error")
req.SetBasicAuth(creds.username, creds.password)
resp, err := client.Do(req)
require.NoError(t, err, "Do should not error")
defer resp.Body.Close()
if creds.username == "bobmarley" && creds.password == "Sup3rdup3rS3cr3t" {
var info gctrpc.GetInfoResponse
err = json.NewDecoder(resp.Body).Decode(&info)
require.NoError(t, err, "Decode should not error")
uptimeDuration, err := time.ParseDuration(info.Uptime)
require.NoError(t, err, "ParseDuration should not error")
assert.InDelta(t, time.Since(fakeTime).Seconds(), uptimeDuration.Seconds(), 1.0, "Uptime should be within 1 second of the expected duration")
} else {
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err, "ReadAll should not error")
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "HTTP status code should be 401")
assert.Equal(t, "Access denied\n", string(respBody), "Response body should be 'Access denied\n'")
}
})
}
}
func TestRPCProxyAuthClient(t *testing.T) {
t.Parallel()
s := new(RPCServer)
s.Engine = &Engine{
Config: &config.Config{
RemoteControl: config.RemoteControlConfig{
Username: "bobmarley",
Password: "Sup3rdup3rS3cr3t",
},
},
}
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte("MEOW"))
require.NoError(t, err, "Write should not error")
})
handler := s.authClient(dummyHandler)
for _, creds := range []struct {
testDescription string
username string
password string
}{
{"Valid credentials", "bobmarley", "Sup3rdup3rS3cr3t"},
{"Valid username but invalid password", "bobmarley", "wrongpass"},
{"Invalid username but valid password", "bonk", "Sup3rdup3rS3cr3t"},
{"Invalid username and password despite glorious credentials", "bonk", "wif"},
} {
creds := creds
t.Run(creds.testDescription, func(t *testing.T) {
t.Parallel()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", http.NoBody)
require.NoError(t, err, "NewRequestWithContext should not error")
req.SetBasicAuth(creds.username, creds.password)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if creds.username == "bobmarley" && creds.password == "Sup3rdup3rS3cr3t" {
assert.Equal(t, http.StatusOK, rr.Code, "HTTP status code should be 200")
assert.Equal(t, "MEOW", rr.Body.String(), "Response body should be 'MEOW'")
} else {
assert.Equal(t, http.StatusUnauthorized, rr.Code, "HTTP status code should be 401")
assert.Equal(t, "Access denied\n", rr.Body.String(), "Response body should be 'Access denied\n'")
}
})
}
}