diff --git a/config.json b/config.json new file mode 100644 index 0000000..a8c850c --- /dev/null +++ b/config.json @@ -0,0 +1,14 @@ +[ + { + "addr": "192.168.10.201:22", + "user": "root", + "pass": "root", + "tunnels": [ + { + "remote": "127.0.0.1:6379", + "local": "127.0.0.1:6379" + } + ] + } +] + diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a93b533 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module prot + +go 1.17 + +require golang.org/x/crypto v0.0.0-20220214200702-86341886e292 + +require ( + golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect + golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..eaeab1c --- /dev/null +++ b/go.sum @@ -0,0 +1,11 @@ +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/main.go b/main.go new file mode 100644 index 0000000..42724b2 --- /dev/null +++ b/main.go @@ -0,0 +1,44 @@ +package main + +import ( + "encoding/json" + "io/ioutil" + "log" + "os" + "os/signal" + "prot/sshtunnel" + "syscall" +) + +func main() { + var sts []*sshtunnel.Config + p := "config.json" + if len(os.Args) == 2 { + p = os.Args[1] + } + f, err := ioutil.ReadFile(p) + if err != nil { + log.Printf("载入配置文件出错, 错误: %v", err) + os.Exit(-1) + } + err = json.Unmarshal(f, &sts) + if nil != err { + log.Printf("解析配置文件内容出错, 错误: %v", err) + os.Exit(-1) + } + + var tunnels []*sshtunnel.SSHTunnel + for _, st := range sts { + tunnel := sshtunnel.NewSSHTunnel(st) + tunnel.Start() + tunnels = append(tunnels, tunnel) + } + + signalChan := make(chan os.Signal) + signal.Notify(signalChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + <-signalChan + for _, t := range tunnels { + t.Close() + } + os.Exit(0) +} diff --git a/sshtunnel/config.go b/sshtunnel/config.go new file mode 100644 index 0000000..0f413e4 --- /dev/null +++ b/sshtunnel/config.go @@ -0,0 +1,13 @@ +package sshtunnel + +type Tunnel struct { + Remote string `json:"remote"` + Local string `json:"local"` +} + +type Config struct { + Addr string `json:"addr"` + User string `json:"user"` + Pass string `json:"pass,omitempty"` + Tunnels []Tunnel `json:"tunnels,omitempty"` +} diff --git a/sshtunnel/ssh.go b/sshtunnel/ssh.go new file mode 100644 index 0000000..fc37d69 --- /dev/null +++ b/sshtunnel/ssh.go @@ -0,0 +1,148 @@ +package sshtunnel + +import ( + "fmt" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" + "io" + "log" + "math" + "net" + "strings" + "syscall" + "time" +) + +type SSHTunnel struct { + config *Config + client *ssh.Client +} + +func NewSSHTunnel(config *Config) *SSHTunnel { + st := new(SSHTunnel) + st.config = config + return st +} + +func (st *SSHTunnel) Start() { + if len(st.config.Pass) == 0 { + st.setPass() + } + st.initSSHClient() + for _, t := range st.config.Tunnels { + go st.connect(t) + } +} + +func (st *SSHTunnel) Close() { + if nil != st.client { + st.client.Close() + } +} + +func (st *SSHTunnel) GetSSHClient() (*ssh.Client, error) { + if st.client != nil { + return st.client, nil + } + var auth []ssh.AuthMethod + auth = make([]ssh.AuthMethod, 0) + auth = append(auth, ssh.Password(st.config.Pass)) + + sc := &ssh.ClientConfig{ + User: st.config.User, + Auth: auth, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + var err error + st.client, err = ssh.Dial("tcp", st.config.Addr, sc) + if err != nil { + return nil, err + } + log.Printf("连接到服务器成功: %s", st.config.Addr) + return st.client, err +} + +func (st *SSHTunnel) connect(t Tunnel) { + tid := fmt.Sprintf("%s-%s", t.Local, t.Remote) + ll, err := net.Listen("tcp", t.Local) + if err != nil { + log.Printf("隧道[%s]接收开启失败, 错误: %v", tid, err) + return + } + defer func() { + ll.Close() + log.Printf("隧道[%s]接收关闭!", tid) + }() + log.Printf("隧道[%s]接收开启!", tid) + sno := int64(0) + for { + lc, err := ll.Accept() + if err != nil { + log.Printf("隧道[%s]接收连接失败, 错误: %v", tid, err) + return + } + sc, err := st.GetSSHClient() + if err != nil { + log.Printf("隧道[%s]接入服务失败, 错误: %v", tid, err) + lc.Close() + continue + } + rc, err := sc.Dial("tcp", t.Remote) + if err != nil { + log.Printf("隧道[%s]接入获取连接失败, 错误: %v", tid, err) + sc.Close() + lc.Close() + continue + } + if sno >= math.MaxInt64 { + sno = 0 + } + sno += 1 + cid := fmt.Sprintf("%s:%d", tid, sno) + go st.transfer(cid, lc, rc) + } +} + +func (st *SSHTunnel) setPass() { + fmt.Printf("请输入登陆密码[%s@%s]:", st.config.User, st.config.Addr) + bytePassword, _ := terminal.ReadPassword(int(syscall.Stdin)) + st.config.Pass = string(bytePassword) + fmt.Println() +} + +func (st *SSHTunnel) initSSHClient() { + var err error + for { + st.client, err = st.GetSSHClient() + if nil != err { + error := err.Error() + log.Printf("连接到服务器[%s]失败, 错误: %s", st.config.Addr, error) + if strings.Contains(error, "unable to authenticate") { + st.config.Pass = "" + st.setPass() + continue + } + if strings.Contains(error, "i/o timeout") { + log.Printf("连接到服务器[%s]超时!", st.config.Addr) + time.Sleep(2 * time.Second) + continue + } + } + return + } +} + +func (st *SSHTunnel) transfer(cid string, lc net.Conn, rc net.Conn) { + defer rc.Close() + defer lc.Close() + go func() { + defer lc.Close() + defer rc.Close() + io.Copy(rc, lc) + }() + log.Printf("通道[%s]已连接!", cid) + io.Copy(lc, rc) + log.Printf("通道[%s]已断开!", cid) +}