diff --git a/client.go b/client.go index 6443024..ea118fd 100755 --- a/client.go +++ b/client.go @@ -70,7 +70,7 @@ func (s *TRPClient) process(c *Conn) error { return err } case WORK_CHAN: //隧道模式,每次开启10个,加快连接速度 - for i := 0; i < 10; i++ { + for i := 0; i < 100; i++ { go s.dealChan() } case RES_MSG: @@ -86,6 +86,9 @@ func (s *TRPClient) process(c *Conn) error { func (s *TRPClient) dealChan() error { //创建一个tcp连接 conn, err := net.Dial("tcp", s.svrAddr) + if err != nil { + return err + } //验证 if _, err := conn.Write(getverifyval()); err != nil { return err @@ -95,36 +98,31 @@ func (s *TRPClient) dealChan() error { c.SetAlive() //写标志 c.wChan() - //获取连接的host - host, err := c.GetHostFromConn() + //获取连接的host type(tcp or udp) + typeStr, host, err := c.GetHostFromConn() if err != nil { return err } //与目标建立连接 - server, err := net.Dial("tcp", host) + server, err := net.Dial(typeStr, host) if err != nil { - fmt.Println(err) + log.Println(err) return err } - //创建成功后io.copy - go relay(server, c.conn) - relay(c.conn, server) + go relay(NewConn(server), c, DataDecode) + relay(c, NewConn(server), DataEncode) return nil } //http模式处理 func (s *TRPClient) dealHttp(c *Conn) error { - nlen, err := c.GetLen() + buf := make([]byte, 1024*32) + n, err := c.ReadFromCompress(buf, DataDecode) if err != nil { c.wError() return err } - raw, err := c.ReadLen(int(nlen)) - if err != nil { - c.wError() - return err - } - req, err := DecodeRequest(raw) + req, err := DecodeRequest(buf[:n]) if err != nil { c.wError() return err @@ -134,7 +132,8 @@ func (s *TRPClient) dealHttp(c *Conn) error { c.wError() return err } - n, err := c.Write(respBytes) + c.wSign() + n, err = c.WriteCompress(respBytes, DataEncode) if err != nil { return err } diff --git a/conn.go b/conn.go index 41df325..7df1083 100644 --- a/conn.go +++ b/conn.go @@ -2,10 +2,13 @@ package main import ( "bytes" + "compress/gzip" "encoding/binary" "errors" "fmt" + "github.com/golang/snappy" "io" + "log" "net" "net/url" "regexp" @@ -47,7 +50,7 @@ func (s *Conn) ReadLen(len int) ([]byte, error) { //获取长度 func (s *Conn) GetLen() (int, error) { val := make([]byte, 4) - _, err := s.conn.Read(val) + _, err := s.Read(val) if err != nil { return 0, err } @@ -58,6 +61,21 @@ func (s *Conn) GetLen() (int, error) { return int(nlen), nil } +//写入长度 +func (s *Conn) WriteLen(buf []byte) (int, error) { + raw := bytes.NewBuffer([]byte{}) + + if err := binary.Write(raw, binary.LittleEndian, int32(len(buf))); err != nil { + log.Println(err) + return 0, err + } + if err = binary.Write(raw, binary.LittleEndian, buf); err != nil { + log.Println(err) + return 0, err + } + return s.Write(raw.Bytes()) +} + //读取flag func (s *Conn) ReadFlag() (string, error) { val := make([]byte, 4) @@ -69,22 +87,30 @@ func (s *Conn) ReadFlag() (string, error) { } //读取host -func (s *Conn) GetHostFromConn() (string, error) { +func (s *Conn) GetHostFromConn() (typeStr string, host string, err error) { + ltype := make([]byte, 3) + _, err = s.Read(ltype) + if err != nil { + return + } + typeStr = string(ltype) len, err := s.GetLen() if err != nil { - return "", err + return } hostByte := make([]byte, len) _, err = s.conn.Read(hostByte) if err != nil { - return "", err + return } - return string(hostByte), nil + host = string(hostByte) + return } -//获取host -func (s *Conn) WriteHost(host string) (int, error) { +//写tcp host +func (s *Conn) WriteHost(ltype string, host string) (int, error) { raw := bytes.NewBuffer([]byte{}) + binary.Write(raw, binary.LittleEndian, []byte(ltype)) binary.Write(raw, binary.LittleEndian, int32(len([]byte(host)))) binary.Write(raw, binary.LittleEndian, []byte(host)) return s.Write(raw.Bytes()) @@ -139,10 +165,47 @@ func (s *Conn) Write(b []byte) (int, error) { func (s *Conn) Read(b []byte) (int, error) { return s.conn.Read(b) } +func (s *Conn) ReadFromCompress(b []byte, compress int) (int, error) { + switch compress { + case COMPRESS_GZIP_DECODE: + r, err := gzip.NewReader(s) + if err != nil { + return 0, err + } + return r.Read(b) + case COMPRESS_SNAPY_DECODE: + r := snappy.NewReader(s) + return r.Read(b) + case COMPRESS_NONE: + return s.Read(b) + } + return 0, nil +} + +func (s *Conn) WriteCompress(b []byte, compress int) (n int, err error) { + switch compress { + case COMPRESS_GZIP_ENCODE: + w := gzip.NewWriter(s) + if n, err = w.Write(b); err == nil { + w.Flush() + } + case COMPRESS_SNAPY_ENCODE: + w := snappy.NewBufferedWriter(s) + if n, err = w.Write(b); err == nil { + w.Flush() + } + case COMPRESS_NONE: + n, err = s.Write(b) + } + return +} func (s *Conn) wError() { s.conn.Write([]byte(RES_MSG)) } +func (s *Conn) wSign() { + s.conn.Write([]byte(RES_SIGN)) +} func (s *Conn) wMain() { s.conn.Write([]byte(WORK_MAIN)) diff --git a/main.go b/main.go index 905ca8b..5d96cf8 100755 --- a/main.go +++ b/main.go @@ -11,15 +11,33 @@ var ( tcpPort = flag.Int("tcpport", 8284, "Socket连接或者监听的端口") httpPort = flag.Int("httpport", 8024, "当mode为server时为服务端监听端口,当为mode为client时为转发至本地客户端的端口") rpMode = flag.String("mode", "client", "启动模式,可选为client、server") - tunnelTarget = flag.String("target", "10.1.50.203:80", "tunnel模式远程目标") + tunnelTarget = flag.String("target", "10.1.50.203:80", "远程目标") verifyKey = flag.String("vkey", "", "验证密钥") + u = flag.String("u", "", "sock5验证用户名") + p = flag.String("p", "", "sock5验证密码") + compress = flag.String("compress", "", "数据压缩(gizp|snappy)") config Config err error + DataEncode int + DataDecode int ) func main() { flag.Parse() log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) + switch *compress { + case "": + DataDecode = COMPRESS_NONE + DataEncode = COMPRESS_NONE + case "gzip": + DataDecode = COMPRESS_GZIP_DECODE + DataEncode = COMPRESS_GZIP_ENCODE + case "snnapy": + DataDecode = COMPRESS_SNAPY_DECODE + DataEncode = COMPRESS_SNAPY_ENCODE + default: + log.Fatalln("数据压缩格式错误") + } if *rpMode == "client" { JsonParse := NewJsonStruct() config, err = JsonParse.Load(*configPath) @@ -50,11 +68,14 @@ func main() { svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget, ProcessTunnel) svr.Start() } else if *rpMode == "sock5Server" { - svr := NewSock5ModeServer(*tcpPort, *httpPort) + svr := NewSock5ModeServer(*tcpPort, *httpPort, *u, *p) svr.Start() } else if *rpMode == "httpProxyServer" { svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget, ProcessHttp) svr.Start() + } else if *rpMode == "udpServer" { + svr := NewUdpModeServer(*tcpPort, *httpPort, *tunnelTarget) + svr.Start() } } } diff --git a/server.go b/server.go index c4b5f19..c81ed3f 100755 --- a/server.go +++ b/server.go @@ -74,7 +74,8 @@ func (s *HttpModeServer) writeRequest(r *http.Request, conn *Conn) error { if err != nil { return err } - c, err := conn.Write(raw) + conn.wSign() + c, err := conn.WriteCompress(raw, DataEncode) if err != nil { return err } @@ -92,15 +93,12 @@ func (s *HttpModeServer) writeResponse(w http.ResponseWriter, c *Conn) error { } switch flags { case RES_SIGN: - nlen, err := c.GetLen() + buf := make([]byte, 1024*32) + n, err := c.ReadFromCompress(buf, DataDecode) if err != nil { return err } - raw, err := c.ReadLen(nlen) - if err != nil { - return err - } - resp, err := DecodeResponse(raw) + resp, err := DecodeResponse(buf[:n]) if err != nil { return err } @@ -176,12 +174,12 @@ func (s *TunnelModeServer) startTunnelServer() { func ProcessTunnel(c *Conn, s *TunnelModeServer) error { retry: link := s.GetTunnel() - if _, err := link.WriteHost(s.tunnelTarget); err != nil { + if _, err := link.WriteHost("tcp", s.tunnelTarget); err != nil { link.Close() goto retry } - go relay(link.conn, c.conn) - relay(c.conn, link.conn) + go relay(link, c, DataEncode) + relay(c, link, DataDecode) return nil } @@ -194,16 +192,16 @@ func ProcessHttp(c *Conn, s *TunnelModeServer) error { } retry: link := s.GetTunnel() - if _, err := link.WriteHost(addr); err != nil { + if _, err := link.WriteHost("tcp", addr); err != nil { link.Close() goto retry } if method == "CONNECT" { fmt.Fprint(c, "HTTP/1.1 200 Connection established\r\n") } else { - link.Write(rb) + link.WriteCompress(rb, DataEncode) } - go relay(link.conn, c.conn) - relay(c.conn, link.conn) + go relay(link, c, DataEncode) + relay(c, link, DataDecode) return nil } diff --git a/sock5.go b/sock5.go index e6cd575..400bd77 100644 --- a/sock5.go +++ b/sock5.go @@ -10,9 +10,9 @@ import ( ) const ( - ipV4 = 1 - domainName = 3 - ipV6 = 4 + ipV4 = 1 + domainName = 3 + ipV6 = 4 connectMethod = 1 bindMethod = 2 associateMethod = 3 @@ -35,9 +35,19 @@ const ( addrTypeNotSupported ) +const ( + UserPassAuth = uint8(2) + userAuthVersion = uint8(1) + authSuccess = uint8(0) + authFailure = uint8(1) +) + type Sock5ModeServer struct { Tunnel httpPort int + u string //用户名 + p string //密码 + isVerify bool } func (s *Sock5ModeServer) handleRequest(c net.Conn) { @@ -119,37 +129,31 @@ func (s *Sock5ModeServer) doConnect(c net.Conn, command uint8) (proxyConn *Conn, var port uint16 binary.Read(c, binary.BigEndian, &port) - // connect to host addr := net.JoinHostPort(host, strconv.Itoa(int(port))) - //取出一个连接 - if len(s.tunnelList) < 10 { //新建通道 - go s.newChan() - } - client := <-s.tunnelList + client := s.GetTunnel() s.sendReply(c, succeeded) - _, err = client.WriteHost(addr) + var ltype string + if command == associateMethod { + ltype = "udp" + } else { + ltype = "tcp" + } + _, err = client.WriteHost(ltype, addr) return client, nil } func (s *Sock5ModeServer) handleConnect(c net.Conn) { proxyConn, err := s.doConnect(c, connectMethod) if err != nil { + log.Println(err) c.Close() } else { - go io.Copy(c, proxyConn) - go io.Copy(proxyConn, c) + go relay(proxyConn, NewConn(c), DataEncode) + go relay(NewConn(c), proxyConn, DataDecode) } } - -func (s *Sock5ModeServer) relay(in, out net.Conn) { - if _, err := io.Copy(in, out); err != nil { - log.Println("copy error", err) - } - in.Close() // will trigger an error in the other relay, then call out.Close() -} - // passive mode func (s *Sock5ModeServer) handleBind(c net.Conn) { } @@ -177,8 +181,8 @@ func (s *Sock5ModeServer) handleUDP(c net.Conn) { if err != nil { c.Close() } else { - go io.Copy(c, proxyConn) - go io.Copy(proxyConn, c) + go relay(proxyConn, NewConn(c), DataEncode) + go relay(NewConn(c), proxyConn, DataDecode) } } @@ -203,14 +207,56 @@ func (s *Sock5ModeServer) handleNewConn(c net.Conn) { c.Close() return } - // no authentication required for now - buf[1] = 0 - // send a METHOD selection message - c.Write(buf) - + if s.isVerify { + buf[1] = UserPassAuth + c.Write(buf) + if err := s.Auth(c); err != nil { + c.Close() + log.Println("验证失败:", err) + return + } + } else { + buf[1] = 0 + c.Write(buf) + } s.handleRequest(c) } +func (s *Sock5ModeServer) Auth(c net.Conn) error { + header := []byte{0, 0} + if _, err := io.ReadAtLeast(c, header, 2); err != nil { + return err + } + if header[0] != userAuthVersion { + return errors.New("验证方式不被支持") + } + userLen := int(header[1]) + user := make([]byte, userLen) + if _, err := io.ReadAtLeast(c, user, userLen); err != nil { + return err + } + if _, err := c.Read(header[:1]); err != nil { + return errors.New("密码长度获取错误") + } + passLen := int(header[0]) + pass := make([]byte, passLen) + if _, err := io.ReadAtLeast(c, pass, passLen); err != nil { + return err + } + if string(pass) == s.p && string(user) == s.u { + if _, err := c.Write([]byte{userAuthVersion, authSuccess}); err != nil { + return err + } + return nil + } else { + if _, err := c.Write([]byte{userAuthVersion, authFailure}); err != nil { + return err + } + return errors.New("验证不通过") + } + return errors.New("未知错误") +} + func (s *Sock5ModeServer) Start() { l, err := net.Listen("tcp", ":"+strconv.Itoa(s.httpPort)) if err != nil { @@ -226,11 +272,18 @@ func (s *Sock5ModeServer) Start() { } } -func NewSock5ModeServer(tcpPort, httpPort int) *Sock5ModeServer { +func NewSock5ModeServer(tcpPort, httpPort int, u, p string) *Sock5ModeServer { s := new(Sock5ModeServer) s.tunnelPort = tcpPort s.httpPort = httpPort s.tunnelList = make(chan *Conn, 1000) s.signalList = make(chan *Conn, 10) + if u != "" && p != "" { + s.isVerify = true + s.u = u + s.p = p + } else { + s.isVerify = false + } return s } diff --git a/udp.go b/udp.go new file mode 100755 index 0000000..467ca97 --- /dev/null +++ b/udp.go @@ -0,0 +1,80 @@ +package main + +import ( + "io" + "log" + "net" + "time" +) + +type UdpModeServer struct { + Tunnel + udpPort int //监听的udp端口 + tunnelTarget string //udp目标地址 + listener *net.UDPConn + udpMap map[string]*Conn +} + +func NewUdpModeServer(tcpPort, udpPort int, tunnelTarget string) *UdpModeServer { + s := new(UdpModeServer) + s.tunnelPort = tcpPort + s.udpPort = udpPort + s.tunnelTarget = tunnelTarget + s.tunnelList = make(chan *Conn, 1000) + s.signalList = make(chan *Conn, 10) + s.udpMap = make(map[string]*Conn) + return s +} + +//开始 +func (s *UdpModeServer) Start() (error) { + err := s.StartTunnel() + if err != nil { + log.Fatalln("启动失败!", err) + return err + } + s.startTunnelServer() + return nil +} + +//udp监听 +func (s *UdpModeServer) startTunnelServer() { + s.listener, err = net.ListenUDP("udp", &net.UDPAddr{net.ParseIP("0.0.0.0"), s.udpPort, ""}) + if err != nil { + log.Fatalln(err) + } + data := make([]byte, 1472) //udp数据包大小 + for { + n, addr, err := s.listener.ReadFromUDP(data) + if err != nil { + log.Println(err) + continue + } + go s.process(addr, data[:n]) + } +} + +func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) { + conn := s.GetTunnel() + conn.WriteHost("udp", s.tunnelTarget) + go func() { + for { + buf := make([]byte, 1024) + conn.conn.SetReadDeadline(time.Now().Add(time.Duration(time.Second * 3))) + n, err := conn.ReadFromCompress(buf, DataDecode) + if err != nil || err == io.EOF { + conn.Close() + break + } + _, err = s.listener.WriteToUDP(buf[:n], addr) + if err != nil { + conn.Close() + break + } + } + + }() + if _, err = conn.WriteCompress(data, DataEncode); err != nil { + conn.Close() + } +} diff --git a/util.go b/util.go index 49f4e12..265c085 100644 --- a/util.go +++ b/util.go @@ -7,8 +7,9 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/golang/snappy" "io" - "net" + "log" "net/http" "net/http/httputil" "net/url" @@ -20,6 +21,14 @@ var ( disabledRedirect = errors.New("disabled redirect.") ) +const ( + COMPRESS_NONE = iota + COMPRESS_SNAPY_ENCODE + COMPRESS_SNAPY_DECODE + COMPRESS_GZIP_ENCODE + COMPRESS_GZIP_DECODE +) + func BadRequest(w http.ResponseWriter) { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) } @@ -46,30 +55,20 @@ func GetEncodeResponse(req *http.Request) ([]byte, error) { return respBytes, nil } -// 将request 的处理 +// 将request转为bytes func EncodeRequest(r *http.Request) ([]byte, error) { raw := bytes.NewBuffer([]byte{}) - // 写签名 - binary.Write(raw, binary.LittleEndian, []byte("sign")) reqBytes, err := httputil.DumpRequest(r, true) if err != nil { return nil, err } - // 写body数据长度 + 1 - binary.Write(raw, binary.LittleEndian, int32(len(reqBytes)+1)) - // 判断是否为http或者https的标识1字节 binary.Write(raw, binary.LittleEndian, bool(r.URL.Scheme == "https")) - if err := binary.Write(raw, binary.LittleEndian, reqBytes); err != nil { - return nil, err - } + binary.Write(raw, binary.LittleEndian, reqBytes) return raw.Bytes(), nil } // 将字节转为request func DecodeRequest(data []byte) (*http.Request, error) { - if len(data) <= 100 { - return nil, errors.New("待解码的字节长度太小") - } req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(data[1:]))) if err != nil { return nil, err @@ -84,42 +83,25 @@ func DecodeRequest(data []byte) (*http.Request, error) { scheme = "https" } req.URL, _ = url.Parse(fmt.Sprintf("%s://%s%s", scheme, req.Host, req.RequestURI)) - fmt.Println(req.URL) req.RequestURI = "" return req, nil } //// 将response转为字节 func EncodeResponse(r *http.Response) ([]byte, error) { - raw := bytes.NewBuffer([]byte{}) - binary.Write(raw, binary.LittleEndian, []byte(RES_SIGN)) respBytes, err := httputil.DumpResponse(r, true) - if config.Replace == 1 { - respBytes = replaceHost(respBytes) - } if err != nil { return nil, err } - var buf bytes.Buffer - zw := gzip.NewWriter(&buf) - zw.Write(respBytes) - zw.Close() - binary.Write(raw, binary.LittleEndian, int32(len(buf.Bytes()))) - if err := binary.Write(raw, binary.LittleEndian, buf.Bytes()); err != nil { - fmt.Println(err) - return nil, err + if config.Replace == 1 { + respBytes = replaceHost(respBytes) } - return raw.Bytes(), nil + return respBytes, nil } // 将字节转为response func DecodeResponse(data []byte) (*http.Response, error) { - zr, err := gzip.NewReader(bytes.NewReader(data)) - if err != nil { - return nil, err - } - defer zr.Close() - resp, err := http.ReadResponse(bufio.NewReader(zr), nil) + resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(data)), nil) if err != nil { return nil, err } @@ -144,7 +126,52 @@ func replaceHost(resp []byte) []byte { return []byte(str) } -func relay(in, out net.Conn) { - io.Copy(in, out); - in.Close() +func relay(in, out *Conn, compressType int) { + buf := make([]byte, 32*1024) + switch compressType { + case COMPRESS_GZIP_ENCODE: + w := gzip.NewWriter(in) + for { + n, err := out.Read(buf) + if err != nil || err == io.EOF { + break + } + if _, err = w.Write(buf[:n]); err != nil { + break + } + if err = w.Flush(); err != nil { + log.Println(err) + break + } + } + w.Close() + case COMPRESS_SNAPY_ENCODE: + w := snappy.NewBufferedWriter(in) + for { + n, err := out.Read(buf) + if err != nil || err == io.EOF { + break + } + if _, err = w.Write(buf[:n]); err != nil { + break + } + if err = w.Flush(); err != nil { + log.Println(err) + break + } + } + w.Close() + case COMPRESS_GZIP_DECODE: + r, err := gzip.NewReader(out) + if err != nil { + return + } + io.Copy(in, r) + case COMPRESS_SNAPY_DECODE: + r := snappy.NewReader(out) + io.Copy(in, r) + default: + io.Copy(in, out) + } + out.Close() }