diff --git a/cmd/tlsrouter/config.go b/cmd/tlsrouter/config.go index 1c8151f..acc8ac3 100644 --- a/cmd/tlsrouter/config.go +++ b/cmd/tlsrouter/config.go @@ -69,8 +69,10 @@ func (c *Config) Match(hostname string) (string, bool) { } for _, r := range c.routes { - if r.match.MatchString(hostname) { - return r.backend, r.proxyInfo + matches := r.match.FindStringSubmatchIndex(hostname) + if matches != nil { + result := r.match.ExpandString(nil, r.backend, hostname, matches) + return string(result), r.proxyInfo } } return "", false diff --git a/cmd/tlsrouter/main.go b/cmd/tlsrouter/main.go index ff1a816..5b85542 100644 --- a/cmd/tlsrouter/main.go +++ b/cmd/tlsrouter/main.go @@ -16,21 +16,55 @@ package main import ( "bytes" + "context" + "crypto/tls" "flag" "fmt" "io" "log" "net" + "strings" "sync" "time" ) var ( - cfgFile = flag.String("conf", "", "configuration file") - listen = flag.String("listen", ":443", "listening port") - helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello") + cfgFile = flag.String("conf", "", "configuration file") + listen = flag.String("listen", ":443", "listening port") + helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello") + helloMss = flag.Int64("hello-mss", 0, "how many bytes to fragment/segment the TLS ClientHello") + resolverAddress = flag.String("dns", "", "address of the dns resolver") + resolverNetwork = flag.String("dns-net", "", "protocol for the dns resolver (e.g. \"tcp-tls\" or \"tcp\" or \"udp\")") ) +// BackendDialer with timeout +var BackendDialer = &net.Dialer{ + Timeout: 15 * time.Second, + Resolver: &net.Resolver{ + PreferGo: true, + Dial: dialDNSResolver, + }, +} + +// ResolverDialer with timeout +var ResolverDialer = &net.Dialer{ + Timeout: 10 * time.Second, +} + +// ResolverTLSConfig for DNS-over-TLS +var ResolverTLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP521, tls.CurveP384, tls.CurveP256}, + CipherSuites: []uint16{ + // tls.TLS_CHACHA20_POLY1305_SHA256, + // tls.TLS_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + }, +} + func main() { flag.Parse() @@ -142,7 +176,7 @@ func (c *Conn) proxy() { } c.logf("routing %q to %q", c.hostname, c.backend) - backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second) + backend, err := BackendDialer.Dial("tcp", c.backend) if err != nil { c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err) return @@ -168,7 +202,20 @@ func (c *Conn) proxy() { // Replay the piece of the handshake we had to read to do the // routing, then blindly proxy any other bytes. - if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil { + if *helloMss == 0 { + _, err = io.Copy(c.backendConn, &handshakeBuf) + } else { + for { + _, err = io.CopyN(c.backendConn, &handshakeBuf, *helloMss) + if err != nil { + if err == io.EOF { + err = nil + } + break + } + } + } + if err != nil { c.internalError("failed to replay handshake to %q: %s", c.backend, err) return } @@ -189,3 +236,26 @@ func proxy(wg *sync.WaitGroup, a, b net.Conn) { btcp.CloseWrite() atcp.CloseRead() } + +func dialDNSResolver(ctx context.Context, network, address string) (net.Conn, error) { + if *resolverNetwork != "" { + network = *resolverNetwork + } + if *resolverAddress != "" { + address = *resolverAddress + } + + useTLS := strings.HasPrefix(network, "tcp") && strings.HasSuffix(network, "-tls") + if useTLS { + network = strings.TrimSuffix(network, "-tls") + if !strings.Contains(address, ":") { + address += ":853" + } + return tls.DialWithDialer(ResolverDialer, network, address, ResolverTLSConfig) + } + + if !strings.Contains(address, ":") { + address += ":53" + } + return ResolverDialer.DialContext(ctx, network, address) +}