Skip to content

Add custom dns to resolve host #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ var (
httpHead bool
httpPost bool
httpUA string

dnsServer []string
)

var rootCmd = cobra.Command{
Expand Down Expand Up @@ -110,6 +112,9 @@ var rootCmd = cobra.Command{
return
}
}
if len(dnsServer) != 0 {
ping.UseCustomeDNS(dnsServer)
}
target := ping.Target{
Timeout: timeoutDuration,
Interval: intervalDuration,
Expand Down Expand Up @@ -163,6 +168,8 @@ func init() {
rootCmd.Flags().BoolVar(&httpPost, "post", false, `Use HEAD instead of GET in http mode.`)
rootCmd.Flags().StringVar(&httpUA, "user-agent", "tcping", `Use custom UA in http mode.`)

rootCmd.Flags().StringArrayVarP(&dnsServer, "dns-server", "D", nil, `Use the specified dns resolve server.`)

}

func main() {
Expand Down
22 changes: 15 additions & 7 deletions ping/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptrace"
"time"
)

Expand Down Expand Up @@ -47,15 +49,15 @@ func (ping *HTTPing) Start() <-chan struct{} {
ping.Stop()
return
}
duration, resp, err := ping.ping()
duration, resp, remoteAddr, err := ping.ping()
ping.result.Counter++

if err != nil {
fmt.Printf("Ping %s - failed: %s\n", ping.target, err)
} else {
defer resp.Body.Close()
length, _ := io.Copy(ioutil.Discard, resp.Body)
fmt.Printf("Ping %s - %s is open - time=%s method=%s status=%d bytes=%d\n", ping.target, ping.target.Protocol, duration, ping.Method, resp.StatusCode, length)
fmt.Printf("Ping %s(%s) - %s is open - time=%s method=%s status=%d bytes=%d\n", ping.target, remoteAddr, ping.target.Protocol, duration, ping.Method, resp.StatusCode, length)
if ping.result.MinDuration == 0 {
ping.result.MinDuration = duration
}
Expand Down Expand Up @@ -88,7 +90,7 @@ func (ping *HTTPing) Stop() {
ping.done <- struct{}{}
}

func (ping HTTPing) ping() (time.Duration, *http.Response, error) {
func (ping HTTPing) ping() (time.Duration, *http.Response, net.Addr, error) {
var resp *http.Response
var body io.Reader
if ping.Method == "POST" {
Expand All @@ -97,17 +99,23 @@ func (ping HTTPing) ping() (time.Duration, *http.Response, error) {
req, err := http.NewRequest(ping.Method, ping.target.String(), body)
req.Header.Set(http.CanonicalHeaderKey("User-Agent"), "tcping")
if err != nil {
return 0, nil, err
return 0, nil, nil, err
}

var remoteAddr net.Addr
trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
remoteAddr = connInfo.Conn.RemoteAddr()
},
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
duration, errIfce := timeIt(func() interface{} {
client := http.Client{Timeout: ping.target.Timeout}
resp, err = client.Do(req)
return err
})
if errIfce != nil {
err := errIfce.(error)
return 0, nil, err
return 0, nil, nil, err
}
return time.Duration(duration), resp, nil
return time.Duration(duration), resp, remoteAddr, nil
}
13 changes: 8 additions & 5 deletions ping/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,21 @@ func (tcping TCPing) Result() *Result {
func (tcping TCPing) Start() <-chan struct{} {
go func() {
t := time.NewTicker(tcping.target.Interval)
defer t.Stop()
for {
select {
case <-t.C:
if tcping.result.Counter >= tcping.target.Counter && tcping.target.Counter != 0 {
tcping.Stop()
return
}
duration, err := tcping.ping()
duration, remoteAddr, err := tcping.ping()
tcping.result.Counter++

if err != nil {
fmt.Printf("Ping %s - failed: %s\n", tcping.target, err)
} else {
fmt.Printf("Ping %s - Connected - time=%s\n", tcping.target, duration)
fmt.Printf("Ping %s(%s) - Connected - time=%s\n", tcping.target, remoteAddr, duration)

if tcping.result.MinDuration == 0 {
tcping.result.MinDuration = duration
Expand Down Expand Up @@ -82,18 +83,20 @@ func (tcping *TCPing) Stop() {
tcping.done <- struct{}{}
}

func (tcping TCPing) ping() (time.Duration, error) {
func (tcping TCPing) ping() (time.Duration, net.Addr, error) {
var remoteAddr net.Addr
duration, errIfce := timeIt(func() interface{} {
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", tcping.target.Host, tcping.target.Port), tcping.target.Timeout)
if err != nil {
return err
}
remoteAddr = conn.RemoteAddr()
conn.Close()
return nil
})
if errIfce != nil {
err := errIfce.(error)
return 0, err
return 0, remoteAddr, err
}
return time.Duration(duration), nil
return time.Duration(duration), remoteAddr, nil
}
24 changes: 23 additions & 1 deletion ping/utils.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,32 @@
package ping

import "time"
import (
"context"
"net"
"time"
)

func timeIt(f func() interface{}) (int64, interface{}) {
startAt := time.Now()
res := f()
endAt := time.Now()
return endAt.UnixNano() - startAt.UnixNano(), res
}

// UseCustomeDNS will set the dns to default DNS resolver for global
func UseCustomeDNS(dns []string) {
resolver := net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) {
for _, addr := range dns {
if conn, err = net.Dial("udp", addr+":53"); err != nil {
continue
} else {
return conn, nil
}
}
return
},
}
net.DefaultResolver = &resolver
}