diff --git a/engine/helpers.go b/engine/helpers.go index 4c531f86..1700b8cb 100644 --- a/engine/helpers.go +++ b/engine/helpers.go @@ -107,7 +107,7 @@ func (bot *Engine) GetRPCEndpoints() (map[string]RPCEndpoint, error) { }, grpcProxyName: { Started: bot.Settings.EnableGRPCProxy, - ListenAddr: "http://" + bot.Config.RemoteControl.GRPC.GRPCProxyListenAddress, + ListenAddr: "https://" + bot.Config.RemoteControl.GRPC.GRPCProxyListenAddress, }, DeprecatedName: { Started: bot.Settings.EnableDeprecatedRPC, diff --git a/engine/rpcserver.go b/engine/rpcserver.go index a2209987..41d81669 100644 --- a/engine/rpcserver.go +++ b/engine/rpcserver.go @@ -172,12 +172,14 @@ func StartRPCServer(engine *Engine) { // StartRPCRESTProxy starts a gRPC proxy func (s *RPCServer) StartRPCRESTProxy() { - log.Debugf(log.GRPCSys, "gRPC proxy server support enabled. Starting gRPC proxy server on http://%v.\n", s.Config.RemoteControl.GRPC.GRPCProxyListenAddress) + log.Debugf(log.GRPCSys, "gRPC proxy server support enabled. Starting gRPC proxy server on https://%v.\n", s.Config.RemoteControl.GRPC.GRPCProxyListenAddress) targetDir := utils.GetTLSDir(s.Settings.DataDir) - creds, err := credentials.NewClientTLSFromFile(filepath.Join(targetDir, "cert.pem"), "") + certFile := filepath.Join(targetDir, "cert.pem") + keyFile := filepath.Join(targetDir, "key.pem") + creds, err := credentials.NewClientTLSFromFile(certFile, "") if err != nil { - log.Errorf(log.GRPCSys, "Unabled to start gRPC proxy. Err: %s\n", err) + log.Errorf(log.GRPCSys, "Unable to start gRPC proxy. Err: %s\n", err) return } @@ -200,16 +202,31 @@ func (s *RPCServer) StartRPCRESTProxy() { Addr: s.Config.RemoteControl.GRPC.GRPCProxyListenAddress, ReadHeaderTimeout: time.Minute, ReadTimeout: time.Minute, + Handler: s.authClient(mux), } - if err = server.ListenAndServe(); err != nil { - log.Errorf(log.GRPCSys, "GRPC proxy failed to server: %s\n", err) + if err = server.ListenAndServeTLS(certFile, keyFile); err != nil { + log.Errorf(log.GRPCSys, "gRPC proxy server failed to serve: %s\n", err) + return } }() log.Debugln(log.GRPCSys, "gRPC proxy server started!") } +func (s *RPCServer) authClient(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok || username != s.Config.RemoteControl.Username || password != s.Config.RemoteControl.Password { + w.Header().Set("WWW-Authenticate", `Basic realm="restricted"`) + http.Error(w, "Access denied", http.StatusUnauthorized) + log.Warnf(log.GRPCSys, "gRPC proxy server unauthorised access attempt. IP: %s Path: %s\n", r.RemoteAddr, r.URL.Path) + return + } + handler.ServeHTTP(w, r) + }) +} + // GetInfo returns info about the current GoCryptoTrader session func (s *RPCServer) GetInfo(_ context.Context, _ *gctrpc.GetInfoRequest) (*gctrpc.GetInfoResponse, error) { rpcEndpoints, err := s.getRPCEndpoints() diff --git a/engine/rpcserver_test.go b/engine/rpcserver_test.go index c1f1a08b..84a15b7f 100644 --- a/engine/rpcserver_test.go +++ b/engine/rpcserver_test.go @@ -2,12 +2,20 @@ package engine import ( "context" + "crypto/tls" + "crypto/x509" + "encoding/json" "errors" "fmt" + "io" "log" + "math/rand" + "net/http" + "net/http/httptest" "os" "path/filepath" "reflect" + "strconv" "strings" "sync" "testing" @@ -16,6 +24,7 @@ import ( "github.com/gofrs/uuid" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/convert" "github.com/thrasher-corp/gocryptotrader/common/key" @@ -4139,3 +4148,145 @@ func TestGetOpenInterest(t *testing.T) { _, err = s.GetOpenInterest(context.Background(), req) assert.NoError(t, err) } + +func TestStartRPCRESTProxy(t *testing.T) { + t.Parallel() + + tempDir := filepath.Join(os.TempDir(), "gct-grpc-proxy-test") + tempDirTLS := filepath.Join(tempDir, "tls") + + t.Cleanup(func() { + assert.NoErrorf(t, os.RemoveAll(tempDir), "RemoveAll should not error, manual directory deletion required for TempDir: %s", tempDir) + }) + + if !assert.NoError(t, genCert(tempDirTLS), "genCert should not error") { + t.FailNow() + } + + gRPCPort := rand.Intn(65535-42069) + 42069 //nolint:gosec // Don't require crypto/rand usage here + gRPCProxyPort := gRPCPort + 1 + + e := &Engine{ + Config: &config.Config{ + RemoteControl: config.RemoteControlConfig{ + Username: "bobmarley", + Password: "Sup3rdup3rS3cr3t", + GRPC: config.GRPCConfig{ + Enabled: true, + ListenAddress: "localhost:" + strconv.Itoa(gRPCPort), + GRPCProxyListenAddress: "localhost:" + strconv.Itoa(gRPCProxyPort), + }, + }, + }, + Settings: Settings{ + DataDir: tempDir, + CoreSettings: CoreSettings{EnableGRPCProxy: true}, + }, + } + + fakeTime := time.Now().Add(-time.Hour) + e.uptime = fakeTime + + StartRPCServer(e) + + // Give the proxy time to start + time.Sleep(time.Millisecond * 500) + + certFile := filepath.Join(tempDirTLS, "cert.pem") + caCert, err := os.ReadFile(certFile) + require.NoError(t, err, "ReadFile should not error") + caCertPool := x509.NewCertPool() + ok := caCertPool.AppendCertsFromPEM(caCert) + require.True(t, ok, "AppendCertsFromPEM should return true") + client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}}} + + for _, creds := range []struct { + testDescription string + username string + password string + }{ + {"Valid credentials", "bobmarley", "Sup3rdup3rS3cr3t"}, + {"Valid username but invalid password", "bobmarley", "wrongpass"}, + {"Invalid username but valid password", "bonk", "Sup3rdup3rS3cr3t"}, + {"Invalid username and password despite glorious credentials", "bonk", "wif"}, + } { + creds := creds + t.Run(creds.testDescription, func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://localhost:"+strconv.Itoa(gRPCProxyPort)+"/v1/getinfo", http.NoBody) + require.NoError(t, err, "NewRequestWithContext should not error") + req.SetBasicAuth(creds.username, creds.password) + resp, err := client.Do(req) + require.NoError(t, err, "Do should not error") + defer resp.Body.Close() + + if creds.username == "bobmarley" && creds.password == "Sup3rdup3rS3cr3t" { + var info gctrpc.GetInfoResponse + err = json.NewDecoder(resp.Body).Decode(&info) + require.NoError(t, err, "Decode should not error") + + uptimeDuration, err := time.ParseDuration(info.Uptime) + require.NoError(t, err, "ParseDuration should not error") + assert.InDelta(t, time.Since(fakeTime).Seconds(), uptimeDuration.Seconds(), 1.0, "Uptime should be within 1 second of the expected duration") + } else { + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err, "ReadAll should not error") + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "HTTP status code should be 401") + assert.Equal(t, "Access denied\n", string(respBody), "Response body should be 'Access denied\n'") + } + }) + } +} + +func TestRPCProxyAuthClient(t *testing.T) { + t.Parallel() + + s := new(RPCServer) + s.Engine = &Engine{ + Config: &config.Config{ + RemoteControl: config.RemoteControlConfig{ + Username: "bobmarley", + Password: "Sup3rdup3rS3cr3t", + }, + }, + } + + dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("MEOW")) + require.NoError(t, err, "Write should not error") + }) + + handler := s.authClient(dummyHandler) + + for _, creds := range []struct { + testDescription string + username string + password string + }{ + {"Valid credentials", "bobmarley", "Sup3rdup3rS3cr3t"}, + {"Valid username but invalid password", "bobmarley", "wrongpass"}, + {"Invalid username but valid password", "bonk", "Sup3rdup3rS3cr3t"}, + {"Invalid username and password despite glorious credentials", "bonk", "wif"}, + } { + creds := creds + t.Run(creds.testDescription, func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", http.NoBody) + require.NoError(t, err, "NewRequestWithContext should not error") + req.SetBasicAuth(creds.username, creds.password) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if creds.username == "bobmarley" && creds.password == "Sup3rdup3rS3cr3t" { + assert.Equal(t, http.StatusOK, rr.Code, "HTTP status code should be 200") + assert.Equal(t, "MEOW", rr.Body.String(), "Response body should be 'MEOW'") + } else { + assert.Equal(t, http.StatusUnauthorized, rr.Code, "HTTP status code should be 401") + assert.Equal(t, "Access denied\n", rr.Body.String(), "Response body should be 'Access denied\n'") + } + }) + } +}