diff --git a/restful_router.go b/restful_router.go index 14b94c11..cbb1f00a 100644 --- a/restful_router.go +++ b/restful_router.go @@ -6,6 +6,7 @@ import ( "time" "github.com/gorilla/mux" + "github.com/thrasher-/gocryptotrader/common" log "github.com/thrasher-/gocryptotrader/logger" _ "net/http/pprof" @@ -45,6 +46,7 @@ var routes = Routes{} // router func NewRouter() *mux.Router { router := mux.NewRouter().StrictSlash(true) + listenAddr := bot.config.Webserver.ListenAddress routes = Routes{ Route{ @@ -114,7 +116,8 @@ func NewRouter() *mux.Router { Methods(route.Method). Path(route.Pattern). Name(route.Name). - Handler(RESTLogger(route.HandlerFunc, route.Name)) + Handler(RESTLogger(route.HandlerFunc, route.Name)). + Host(common.ExtractHost(listenAddr)) } if bot.config.Profiler.Enabled { diff --git a/restful_server_test.go b/restful_server_test.go index 80cf3bd8..88a450c4 100644 --- a/restful_server_test.go +++ b/restful_server_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + "github.com/thrasher-/gocryptotrader/common" "github.com/thrasher-/gocryptotrader/config" ) @@ -51,3 +52,33 @@ func TestConfigAllJsonResponse(t *testing.T) { t.Error("Test failed. Json not equal to config") } } + +func TestInvalidHostRequest(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/config/all", nil) + if err != nil { + t.Fatal(err) + } + req.Host = "invalidsite.com" + + resp := httptest.NewRecorder() + NewRouter().ServeHTTP(resp, req) + + if status := resp.Code; status != http.StatusNotFound { + t.Errorf("Test failed. Response returned wrong status code expected %v got %v", http.StatusNotFound, status) + } +} + +func TestValidHostRequest(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/config/all", nil) + if err != nil { + t.Fatal(err) + } + req.Host = common.ExtractHost(bot.config.Webserver.ListenAddress) + + resp := httptest.NewRecorder() + NewRouter().ServeHTTP(resp, req) + + if status := resp.Code; status != http.StatusOK { + t.Errorf("Test failed. Response returned wrong status code expected %v got %v", http.StatusOK, status) + } +}