diff --git a/README.md b/README.md index 6ed0bda..c9c2447 100644 --- a/README.md +++ b/README.md @@ -14,9 +14,15 @@ For other platforms or to build from source, clone the repository and just run ` Usage of dist/tftp-http-proxy: -http-base-url string - HTTP base URL (default "http://127.0.0.1/tftp") + HTTP base URL (default "http://127.0.0.1/tftp") + -http-append-path + Append path to url (add slash and then requested path), warning + no path sanitization is done, so it may contain /../ or other + wild characters. -tftp-timeout duration - TFTP timeout (default 5s) + TFTP timeout (default 5s) + -tftp-bind-address address + UDP address to bind to ## Details diff --git a/main.go b/main.go index 4fcc9cf..5ae9df7 100644 --- a/main.go +++ b/main.go @@ -27,25 +27,6 @@ var globalState = struct { appendPath: appendPathDefault, } -func urlJoin(base string, other string) (string, error) { - if !strings.HasSuffix(base, "/") { - base = base + "/" - } - - b, err := url.Parse(base) - if err != nil { - return "", err - } - - o, err := url.Parse(strings.TrimPrefix(other, "/")) - if err != nil { - return "", err - } - - u := b.ResolveReference(o) - return u.String(), nil -} - func tftpReadHandler(filename string, rf io.ReaderFrom) error { raddr := rf.(tftp.OutgoingTransfer).RemoteAddr() // net.UDPAddr @@ -53,12 +34,12 @@ func tftpReadHandler(filename string, rf io.ReaderFrom) error { uri := globalState.httpBaseUrl if globalState.appendPath { - var err error - uri, err = urlJoin(uri, filename) - if err != nil { - log.Printf("ERR: error building URL: %v", err) - return err - } + // No need to validate url any further, http.NewRequest does + // this for us using url.Parse(). We already checked that base + // contains scheme and host and ends with a slash. We assume + // that appending filename does not change scheme, host and initial + // part of path of URL. + uri = uri + strings.TrimLeft(filename, "/") } req, err := http.NewRequest("GET", uri, nil) @@ -99,6 +80,25 @@ func tftpReadHandler(filename string, rf io.ReaderFrom) error { return nil } +func parseBaseURL(baseUrl string, appendPath bool) string { + u, err := url.ParseRequestURI(baseUrl) + if err != nil { + log.Panicf("FATAL: invalid base URL: %v\n", err) + } + if (u.Scheme == "") { + log.Panicf("FATAL: invalid base URL: No scheme found.\n") + } + if (u.Host == "") { + log.Panicf("FATAL: invalid base URL: No host found.\n") + } + base := u.String() + if appendPath && !strings.HasSuffix(base, "/") { + return base + "/" + } else { + return base + } +} + func main() { httpBaseUrlPtr := flag.String("http-base-url", httpBaseUrlDefault, "HTTP base URL") appendPathPtr := flag.Bool("http-append-path", appendPathDefault, "append TFTP filename to URL") @@ -107,7 +107,7 @@ func main() { flag.Parse() - globalState.httpBaseUrl = *httpBaseUrlPtr + globalState.httpBaseUrl = parseBaseURL(*httpBaseUrlPtr, *appendPathPtr) globalState.httpClient = &http.Client{} globalState.appendPath = *appendPathPtr