diff --git a/main.go b/main.go index 92d5ef5..f076d3d 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,8 @@ import ( ) const ( - errIPv4Only = "Navigator works for valid IPv4 only :)" + errIPv4Only = "Navigator works for valid IPv4 only :)" + remoteAddrHeader = "X-NAV-REMOTE-IP" ) type errorMessage struct { @@ -46,6 +47,23 @@ func responseWithJsonError(resp http.ResponseWriter, statusCode int, message str } } +func getRemoteIP(req *http.Request) string { + if addr := req.Header.Get(remoteAddrHeader); addr != "" { + if net.ParseIP(addr).To4() == nil { + return "" + } + return addr + } + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + return "" + } + if net.ParseIP(host).To4() == nil { + return "" + } + return host +} + func buildLocation(info *ipdb.CityInfo) string { ret := "" if info.CountryName != "" { @@ -74,17 +92,12 @@ func main() { var host string if argIp := req.FormValue("ip"); argIp != "" { host = argIp - } else { - ip, _, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { + if net.ParseIP(host).To4() == nil { responseWithError(resp, http.StatusPreconditionFailed, errIPv4Only) return } - host = ip - } - if net.ParseIP(host).To4() == nil { - responseWithError(resp, http.StatusPreconditionFailed, errIPv4Only) - return + } else { + host = getRemoteIP(req) } db := ipgeo.Get() @@ -120,15 +133,7 @@ func main() { }) http.HandleFunc("/mapping", func(resp http.ResponseWriter, req *http.Request) { - host, _, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { - responseWithError(resp, http.StatusPreconditionFailed, errIPv4Only) - return - } - if net.ParseIP(host).To4() == nil { - responseWithError(resp, http.StatusPreconditionFailed, errIPv4Only) - return - } + host := getRemoteIP(req) resp.Header().Set("Content-Type", "text/plain") resp.WriteHeader(http.StatusOK) server := mapping.Get(host) @@ -151,15 +156,7 @@ func main() { clientV1Api := http.NewServeMux() clientApi.Handle("/v1/", http.StripPrefix("/v1", clientV1Api)) clientV1Api.HandleFunc("/getNodes", func(resp http.ResponseWriter, req *http.Request) { - host, _, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { - responseWithJsonError(resp, http.StatusPreconditionFailed, errIPv4Only) - return - } - if net.ParseIP(host).To4() == nil { - responseWithJsonError(resp, http.StatusPreconditionFailed, errIPv4Only) - return - } + host := getRemoteIP(req) nodes := mapping.GetNodes() if nodes == nil { responseWithJsonError(resp, http.StatusInternalServerError, "Unable to get nodes") @@ -174,9 +171,6 @@ func main() { "suffix": suffix, }) }) - clientV1Api.HandleFunc("/getSuffix", func(resp http.ResponseWriter, req *http.Request) { - }) - log.Println("HTTP server is running on", *listen_spec) http.ListenAndServe(*listen_spec, nil) }