diff --git a/README.md b/README.md index 8b8e046..dd350f0 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +<<<<<<< Updated upstream # easyProxy 轻量级、较高性能http代理服务器,主要应用与内网穿透。支持多站点配置、客户端与服务端连接中断自动重连,多路传输,大大的提高请求处理速度,go语言编写,无第三方依赖,经过测试内存占用小,普通场景下,仅占用10m内存。 @@ -135,12 +136,34 @@ server { 如需开启,请加配置文件Replace值设置为1 >注意:开启可能导致不应该被替换的内容被替换,请谨慎开启 +======= +# rproxy +简单的反向代理用于内网穿透 + +**特别注意,此工具只适合小文件类的访问测试,用来做做数据调试。当初也只是用于微信公众号开发,所以定位也是如此** + +## 前言 +最近周末闲来无事,想起了做下微信公共号的开发,但微信限制只能80端口的,自己用的城中村的那种宽带,共用一个公网,没办法自己用路由做端口映射。自己的服务器在腾讯云上,每次都要编译完后用ftp上传再进行调试,非常的浪费时间。 一时间又不知道上哪找一个符合我的这种要求的工具,就索性自己构思了下,整个工作流程大致为: + +## 工作原理 +> 外部请求自己服务器上的HTTP服务端 -> 将数据传递给Socket服务器 -> Socket服务器将数据发送至已连接的Socket客户端 -> Socket客户端收到数据 -> 使用http请求本地http服务端 -> 本地http服务端处理相关后返回 -> Socket客户端将返回的数据发送至Socket服务端 -> Socket服务端解析出数据后原路返回至外部请求的HTTP + +## 使用方法 +> 1、go get github.com/ying32/rproxy +> 2、go build +> 3、服务端运行runsvr.bat或者runsvr.sh +> 4、客户端运行runcli.bat或者runcli.sh + +## 命令行说明 +> --tcpport Socket连接或者监听的端口 +> --httpport 当mode为server时为服务端监听端口,当为mode为client时为转发至本地客户端的端口 +> --mode 启动模式,可选为client、server,默认为client +> --svraddr 当mode为client时有效,为连接服务器的地址,不需要填写端口 +> --vkey 客户端与服务端建立连接时校验的加密key,简单的。 +>>>>>>> Stashed changes ## 操作系统支持 -支持Windows、Linux、MacOSX等,无第三方依赖库。 - -## 二级域名泛解析配置详细教程 - -[详细教程](https://github.com/cnlh/easyProxy/wiki/%E4%BD%BF%E7%94%A8%E6%95%99%E7%A8%8B) - +支持Windows、Linux、MacOSX等,无第三方依赖库。 +## 二进制下载 +https://github.com/ying32/rproxy/releases/tag/v0.4 diff --git a/client.go b/client.go index 483d17f..a43fa29 100755 --- a/client.go +++ b/client.go @@ -1,20 +1,15 @@ package main import ( - "encoding/binary" "errors" + "fmt" + "io" "log" "net" - "net/http" - "strings" "sync" "time" ) -var ( - disabledRedirect = errors.New("disabled redirect.") -) - type TRPClient struct { svrAddr string tcpNum int @@ -28,56 +23,58 @@ func NewRPClient(svraddr string, tcpNum int) *TRPClient { return c } -func (c *TRPClient) Start() error { - for i := 0; i < c.tcpNum; i++ { - go c.newConn() +func (s *TRPClient) Start() error { + for i := 0; i < s.tcpNum; i++ { + go s.newConn() } for { - time.Sleep(5 * time.Second) + time.Sleep(time.Second * 5) } return nil } -func (c *TRPClient) newConn() error { - c.Lock() - conn, err := net.Dial("tcp", c.svrAddr) +//新建 +func (s *TRPClient) newConn() error { + s.Lock() + conn, err := net.Dial("tcp", s.svrAddr) if err != nil { log.Println("连接服务端失败,五秒后将重连") time.Sleep(time.Second * 5) - c.Unlock() - c.newConn() + s.Unlock() + go s.newConn() return err } - c.Unlock() - conn.(*net.TCPConn).SetKeepAlive(true) - conn.(*net.TCPConn).SetKeepAlivePeriod(time.Duration(2 * time.Second)) - return c.process(conn) + s.Unlock() + return s.process(NewConn(conn)) } -func (c *TRPClient) werror(conn net.Conn) { - conn.Write([]byte("msg0")) -} - -func (c *TRPClient) process(conn net.Conn) error { - if _, err := conn.Write(getverifyval()); err != nil { +func (s *TRPClient) process(c *Conn) error { + c.SetAlive() + if _, err := c.Write(getverifyval()); err != nil { return err } - val := make([]byte, 4) + c.wMain() for { - _, err := conn.Read(val) + flags, err := c.ReadFlag() if err != nil { log.Println("服务端断开,五秒后将重连", err) time.Sleep(5 * time.Second) - go c.newConn() - return err + go s.newConn() + break } - flags := string(val) switch flags { - case "vkey": + case VERIFY_EER: log.Fatal("vkey不正确,请检查配置文件") - case "sign": - c.deal(conn) - case "msg0": + case RES_SIGN: //代理请求模式 + if err := s.dealHttp(c); err != nil { + log.Println(err) + return err + } + case WORK_CHAN: //隧道模式,每次开启10个,加快连接速度 + for i := 0; i < 10; i++ { + go s.dealChan() + } + case RES_MSG: log.Println("服务端返回错误。") default: log.Println("无法解析该错误。") @@ -85,69 +82,64 @@ func (c *TRPClient) process(conn net.Conn) error { } return nil } -func (c *TRPClient) deal(conn net.Conn) error { - val := make([]byte, 4) - _, err := conn.Read(val) - nlen := binary.LittleEndian.Uint32(val) - log.Println("收到服务端数据,长度:", nlen) - if nlen <= 0 { - log.Println("数据长度错误。") - c.werror(conn) - return errors.New("数据长度错误") + +//隧道模式处理 +func (s *TRPClient) dealChan() error { + //创建一个tcp连接 + conn, err := net.Dial("tcp", s.svrAddr) + //验证 + if _, err := conn.Write(getverifyval()); err != nil { + return err } - raw := make([]byte, nlen) - n, err := conn.Read(raw) + //默认长连接保持 + c := NewConn(conn) + c.SetAlive() + //写标志 + c.wChan() + //获取连接的host + host, err := c.GetHostFromConn() if err != nil { return err } - if n != int(nlen) { - log.Printf("读取服务端数据长度错误,已经读取%dbyte,总长度%d字节\n", n, nlen) - c.werror(conn) - return errors.New("读取服务端数据长度错误") + //与目标建立连接 + server, err := net.Dial("tcp", host) + if err != nil { + return err + } + //创建成功后io.copy + go io.Copy(server, c) + io.Copy(c, server) + return nil +} + +//http模式处理 +func (s *TRPClient) dealHttp(c *Conn) error { + nlen, err := c.GetLen() + if err != nil { + c.wError() + return err + } + raw, err := c.ReadLen(int(nlen)) + if err != nil { + c.wError() + return err } req, err := DecodeRequest(raw) if err != nil { - log.Println("DecodeRequest错误:", err) - c.werror(conn) + c.wError() return err } - rawQuery := "" - if req.URL.RawQuery != "" { - rawQuery = "?" + req.URL.RawQuery - } - log.Println(req.URL.Path + rawQuery) - client := new(http.Client) - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return disabledRedirect - } - resp, err := client.Do(req) - disRedirect := err != nil && strings.Contains(err.Error(), disabledRedirect.Error()) - if err != nil && !disRedirect { - log.Println("请求本地客户端错误:", err) - c.werror(conn) - return err - } - if !disRedirect { - defer resp.Body.Close() - } else { - resp.Body = nil - resp.ContentLength = 0 - } - respBytes, err := EncodeResponse(resp) + respBytes, err := GetEncodeResponse(req) if err != nil { - log.Println("EncodeResponse错误:", err) - c.werror(conn) + c.wError() return err } - n, err = conn.Write(respBytes) + n, err := c.Write(respBytes) if err != nil { - log.Println("发送数据错误,错误:", err) return err } if n != len(respBytes) { - log.Printf("发送数据长度错误,已经发送:%dbyte,总字节长:%dbyte\n", n, len(respBytes)) - } else { - log.Printf("本次请求成功完成,共发送:%dbyte\n", n) + return errors.New(fmt.Sprintf("发送数据长度错误,已经发送:%dbyte,总字节长:%dbyte\n", n, len(respBytes))) } return nil } diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..97f256b --- /dev/null +++ b/conn.go @@ -0,0 +1,118 @@ +package main + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "time" +) + +type Conn struct { + conn net.Conn +} + +func NewConn(conn net.Conn) *Conn { + c := new(Conn) + c.conn = conn + return c +} + +//读取指定内容长度 +func (s *Conn) ReadLen(len int) ([]byte, error) { + raw := make([]byte, 0) + buff := make([]byte, 1024) + c := 0 + for { + clen, err := s.conn.Read(buff) + if err != nil && err != io.EOF { + return raw, err + } + raw = append(raw, buff[:clen]...) + if c += clen; c >= len { + break + } + } + if c != len { + return raw, errors.New(fmt.Sprintf("已读取长度错误,已读取%dbyte,需要读取%dbyte。", c, len)) + } + return raw, nil +} + +//获取长度 +func (s *Conn) GetLen() (int, error) { + val := make([]byte, 4) + _, err := s.conn.Read(val) + if err != nil { + return 0, err + } + nlen := binary.LittleEndian.Uint32(val) + if nlen <= 0 { + return 0, errors.New("数据长度错误") + } + return int(nlen), nil +} + +//读取flag +func (s *Conn) ReadFlag() (string, error) { + val := make([]byte, 4) + _, err := s.conn.Read(val) + if err != nil { + return "", err + } + return string(val), err +} + +//读取host +func (s *Conn) GetHostFromConn() (string, error) { + len, err := s.GetLen() + if err != nil { + return "", err + } + hostByte := make([]byte, len) + _, err = s.conn.Read(hostByte) + if err != nil { + return "", err + } + return string(hostByte), nil +} + +//获取host +func (s *Conn) WriteHost(host string) (int, error) { + raw := bytes.NewBuffer([]byte{}) + binary.Write(raw, binary.LittleEndian, int32(len([]byte(host)))) + binary.Write(raw, binary.LittleEndian, []byte(host)) + return s.Write(raw.Bytes()) +} + +//设置连接为长连接 +func (s *Conn) SetAlive() { + conn := s.conn.(*net.TCPConn) + conn.SetReadDeadline(time.Time{}) + conn.SetKeepAlive(true) + conn.SetKeepAlivePeriod(time.Duration(2 * time.Second)) +} + +func (s *Conn) Close() error { + return s.conn.Close() +} +func (s *Conn) Write(b []byte) (int, error) { + return s.conn.Write(b) +} +func (s *Conn) Read(b []byte) (int, error) { + return s.conn.Read(b) +} + +func (s *Conn) wError() { + s.conn.Write([]byte(RES_MSG)) +} + +func (s *Conn) wMain() { + s.conn.Write([]byte(WORK_MAIN)) +} + +func (s *Conn) wChan() { + s.conn.Write([]byte(WORK_CHAN)) +} diff --git a/main.go b/main.go index b761a27..951a38e 100755 --- a/main.go +++ b/main.go @@ -7,13 +7,14 @@ import ( ) var ( - configPath = flag.String("config", "config.json", "配置文件路径") - tcpPort = flag.Int("tcpport", 8284, "Socket连接或者监听的端口") - httpPort = flag.Int("httpport", 8024, "当mode为server时为服务端监听端口,当为mode为client时为转发至本地客户端的端口") - rpMode = flag.String("mode", "client", "启动模式,可选为client、server") - verifyKey = flag.String("vkey", "", "验证密钥") - config Config - err error + configPath = flag.String("config", "config.json", "配置文件路径") + tcpPort = flag.Int("tcpport", 8284, "Socket连接或者监听的端口") + httpPort = flag.Int("httpport", 8024, "当mode为server时为服务端监听端口,当为mode为client时为转发至本地客户端的端口") + rpMode = flag.String("mode", "client", "启动模式,可选为client、server") + tunnelTarget = flag.String("target", "10.1.50.203:80", "tunnel模式远程目标") + verifyKey = flag.String("vkey", "", "验证密钥") + config Config + err error ) func main() { @@ -29,7 +30,7 @@ func main() { log.Println("客户端启动,连接:", config.Server.Ip, ", 端口:", config.Server.Tcp) cli := NewRPClient(fmt.Sprintf("%s:%d", config.Server.Ip, config.Server.Tcp), config.Server.Num) cli.Start() - } else if *rpMode == "server" { + } else { if *verifyKey == "" { log.Fatalln("必须输入一个验证的key") } @@ -39,11 +40,20 @@ func main() { if *httpPort <= 0 || *httpPort >= 65536 { log.Fatalln("请输入正确的http端口。") } - log.Println("服务端启动,监听tcp服务端端口:", *tcpPort, ", http服务端端口:", *httpPort) - svr := NewRPServer(*tcpPort, *httpPort) - if err := svr.Start(); err != nil { - log.Fatalln(err) + log.Println("服务端启动,监听tcp服务端端口:", *tcpPort, ", 外部服务端端口:", *httpPort) + if *rpMode == "httpServer" { + svr := NewHttpModeServer(*tcpPort, *httpPort) + if err := svr.Start(); err != nil { + log.Fatalln(err) + } + } else if *rpMode == "tunnelServer" { + svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget) + if err := svr.Start(); err != nil { + log.Fatalln(err) + } + } else if *rpMode == "sock5Server" { + svr := NewSock5ModeServer(*tcpPort, *httpPort) + svr.Start() } - defer svr.Close() } } diff --git a/server.go b/server.go index 1b91f3a..ad6b857 100755 --- a/server.go +++ b/server.go @@ -1,8 +1,6 @@ package main import ( - "bytes" - "encoding/binary" "errors" "fmt" "io" @@ -10,115 +8,69 @@ import ( "log" "net" "net/http" - "sync" - "time" ) -type TRPServer struct { - tcpPort int +const ( + VERIFY_EER = "vkey" + WORK_MAIN = "main" + WORK_CHAN = "chan" + RES_SIGN = "sign" + RES_MSG = "msg0" +) + +type HttpModeServer struct { + Tunnel httpPort int - listener *net.TCPListener - connList chan net.Conn - sync.RWMutex } -func NewRPServer(tcpPort, httpPort int) *TRPServer { - s := new(TRPServer) - s.tcpPort = tcpPort +func NewHttpModeServer(tcpPort, httpPort int) *HttpModeServer { + s := new(HttpModeServer) + s.tunnelPort = tcpPort s.httpPort = httpPort - s.connList = make(chan net.Conn, 1000) + s.signalList = make(chan *Conn, 1000) return s } -func (s *TRPServer) Start() error { - var err error - s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.tcpPort, ""}) +//开始 +func (s *HttpModeServer) Start() (error) { + err := s.StartTunnel() if err != nil { + log.Fatalln("开启客户端失败!", err) return err } - go s.httpserver() - return s.tcpserver() + s.startHttpServer() + return nil } -func (s *TRPServer) Close() error { - if s.listener != nil { - err := s.listener.Close() - s.listener = nil - return err - } - return errors.New("TCP实例未创建!") -} - -func (s *TRPServer) tcpserver() error { - var err error - for { - conn, err := s.listener.AcceptTCP() - if err != nil { - log.Println(err) - continue - } - go s.cliProcess(conn) - } - return err -} - -func badRequest(w http.ResponseWriter) { - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) -} - -func (s *TRPServer) httpserver() { +//开启http端口监听 +func (s *HttpModeServer) startHttpServer() { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { retry: - if len(s.connList) == 0 { - badRequest(w) + if len(s.signalList) == 0 { + BadRequest(w) return } - conn := <-s.connList - log.Println(r.RequestURI) - err := s.write(r, conn) + conn := <-s.signalList + if err := s.writeRequest(r, conn); err != nil { + log.Println(err) + conn.Close() + goto retry + return + } + err = s.writeResponse(w, conn) if err != nil { log.Println(err) conn.Close() goto retry return } - err = s.read(w, conn) - if err != nil { - log.Println(err) - conn.Close() - goto retry - return - } - s.connList <- conn - conn = nil + s.signalList <- conn }) log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%d", s.httpPort), nil)) } -func (s *TRPServer) cliProcess(conn *net.TCPConn) error { - conn.SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second)) - vval := make([]byte, 20) - _, err := conn.Read(vval) - if err != nil { - log.Println("客户端读超时。客户端地址为::", conn.RemoteAddr()) - conn.Close() - return err - } - if bytes.Compare(vval, getverifyval()[:]) != 0 { - log.Println("当前客户端连接校验错误,关闭此客户端:", conn.RemoteAddr()) - conn.Write([]byte("vkey")) - conn.Close() - return err - } - conn.SetReadDeadline(time.Time{}) - log.Println("连接新的客户端:", conn.RemoteAddr()) - conn.SetKeepAlive(true) - conn.SetKeepAlivePeriod(time.Duration(2 * time.Second)) - s.connList <- conn - return nil -} - -func (s *TRPServer) write(r *http.Request, conn net.Conn) error { +//req转为bytes发送给client端 +func (s *HttpModeServer) writeRequest(r *http.Request, conn *Conn) error { raw, err := EncodeRequest(r) if err != nil { return err @@ -133,41 +85,21 @@ func (s *TRPServer) write(r *http.Request, conn net.Conn) error { return nil } -func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) { - val := make([]byte, 4) - _, err := conn.Read(val) +//从client读取出Response +func (s *HttpModeServer) writeResponse(w http.ResponseWriter, c *Conn) error { + flags, err := c.ReadFlag() if err != nil { return err } - flags := string(val) switch flags { - case "sign": - _, err = conn.Read(val) + case RES_SIGN: + nlen, err := c.GetLen() if err != nil { return err } - nlen := int(binary.LittleEndian.Uint32(val)) - if nlen == 0 { - return errors.New("读取客户端长度错误。") - } - log.Println("收到客户端数据,需要读取长度:", nlen) - raw := make([]byte, 0) - buff := make([]byte, 1024) - c := 0 - for { - clen, err := conn.Read(buff) - if err != nil && err != io.EOF { - return err - } - raw = append(raw, buff[:clen]...) - c += clen - if c >= nlen { - break - } - } - log.Println("读取完成,长度:", c, "实际raw长度:", len(raw)) - if c != nlen { - return fmt.Errorf("已读取长度错误,已读取%dbyte,需要读取%dbyte。", c, nlen) + raw, err := c.ReadLen(nlen) + if err != nil { + return err } resp, err := DecodeResponse(raw) if err != nil { @@ -184,10 +116,70 @@ func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) { } w.WriteHeader(resp.StatusCode) w.Write(bodyBytes) - case "msg0": - return nil + case RES_MSG: + BadRequest(w) + return errors.New("客户端请求出错") default: - log.Println("无法解析此错误", string(val)) + BadRequest(w) + return errors.New("无法解析此错误") } return nil } + +type TunnelModeServer struct { + Tunnel + httpPort int + tunnelTarget string +} + +func NewTunnelModeServer(tcpPort, httpPort int, tunnelTarget string) *TunnelModeServer { + s := new(TunnelModeServer) + s.tunnelPort = tcpPort + s.httpPort = httpPort + s.tunnelTarget = tunnelTarget + s.tunnelList = make(chan *Conn, 1000) + s.signalList = make(chan *Conn, 10) + return s +} + +//开始 +func (s *TunnelModeServer) Start() (error) { + err := s.StartTunnel() + if err != nil { + log.Fatalln("开启客户端失败!", err) + return err + } + s.startTunnelServer() + return nil +} + +//隧道模式server +func (s *TunnelModeServer) startTunnelServer() { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.httpPort, ""}) + if err != nil { + log.Fatalln(err) + } + for { + conn, err := listener.AcceptTCP() + if err != nil { + log.Println(err) + continue + } + go s.process(NewConn(conn)) + } +} + +//监听连接处理 +func (s *TunnelModeServer) process(c *Conn) error { +retry: + if len(s.tunnelList) < 10 { //新建通道 + go s.newChan() + } + link := <-s.tunnelList + if _, err := link.WriteHost(s.tunnelTarget); err != nil { + goto retry + } + go io.Copy(link, c) + io.Copy(c, link.conn) + return nil +} diff --git a/sock5.go b/sock5.go new file mode 100644 index 0000000..e6cd575 --- /dev/null +++ b/sock5.go @@ -0,0 +1,236 @@ +package main + +import ( + "encoding/binary" + "errors" + "io" + "log" + "net" + "strconv" +) + +const ( + ipV4 = 1 + domainName = 3 + ipV6 = 4 + connectMethod = 1 + bindMethod = 2 + associateMethod = 3 + // The maximum packet size of any udp Associate packet, based on ethernet's max size, + // minus the IP and UDP headers. IPv4 has a 20 byte header, UDP adds an + // additional 4 bytes. This is a total overhead of 24 bytes. Ethernet's + // max packet size is 1500 bytes, 1500 - 24 = 1476. + maxUDPPacketSize = 1476 +) + +const ( + succeeded uint8 = iota + serverFailure + notAllowed + networkUnreachable + hostUnreachable + connectionRefused + ttlExpired + commandNotSupported + addrTypeNotSupported +) + +type Sock5ModeServer struct { + Tunnel + httpPort int +} + +func (s *Sock5ModeServer) handleRequest(c net.Conn) { + /* + The SOCKS request is formed as follows: + +----+-----+-------+------+----------+----------+ + |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT | + +----+-----+-------+------+----------+----------+ + | 1 | 1 | X'00' | 1 | Variable | 2 | + +----+-----+-------+------+----------+----------+ + */ + header := make([]byte, 3) + + _, err := io.ReadFull(c, header) + + if err != nil { + log.Println("illegal request", err) + c.Close() + return + } + + switch header[1] { + case connectMethod: + s.handleConnect(c) + case bindMethod: + s.handleBind(c) + case associateMethod: + s.handleUDP(c) + default: + s.sendReply(c, commandNotSupported) + c.Close() + } +} + +func (s *Sock5ModeServer) sendReply(c net.Conn, rep uint8) { + reply := []byte{ + 5, + rep, + 0, + 1, + } + + localAddr := c.LocalAddr().String() + localHost, localPort, _ := net.SplitHostPort(localAddr) + ipBytes := net.ParseIP(localHost).To4() + nPort, _ := strconv.Atoi(localPort) + reply = append(reply, ipBytes...) + portBytes := make([]byte, 2) + binary.BigEndian.PutUint16(portBytes, uint16(nPort)) + reply = append(reply, portBytes...) + + c.Write(reply) +} + +func (s *Sock5ModeServer) doConnect(c net.Conn, command uint8) (proxyConn *Conn, err error) { + addrType := make([]byte, 1) + c.Read(addrType) + var host string + switch addrType[0] { + case ipV4: + ipv4 := make(net.IP, net.IPv4len) + c.Read(ipv4) + host = ipv4.String() + case ipV6: + ipv6 := make(net.IP, net.IPv6len) + c.Read(ipv6) + host = ipv6.String() + case domainName: + var domainLen uint8 + binary.Read(c, binary.BigEndian, &domainLen) + domain := make([]byte, domainLen) + c.Read(domain) + host = string(domain) + default: + s.sendReply(c, addrTypeNotSupported) + err = errors.New("Address type not supported") + return nil, err + } + + var port uint16 + binary.Read(c, binary.BigEndian, &port) + + // connect to host + addr := net.JoinHostPort(host, strconv.Itoa(int(port))) + //取出一个连接 + if len(s.tunnelList) < 10 { //新建通道 + go s.newChan() + } + client := <-s.tunnelList + s.sendReply(c, succeeded) + _, err = client.WriteHost(addr) + return client, nil +} + +func (s *Sock5ModeServer) handleConnect(c net.Conn) { + proxyConn, err := s.doConnect(c, connectMethod) + if err != nil { + c.Close() + } else { + go io.Copy(c, proxyConn) + go io.Copy(proxyConn, c) + } + +} + +func (s *Sock5ModeServer) relay(in, out net.Conn) { + if _, err := io.Copy(in, out); err != nil { + log.Println("copy error", err) + } + in.Close() // will trigger an error in the other relay, then call out.Close() +} + +// passive mode +func (s *Sock5ModeServer) handleBind(c net.Conn) { +} + +func (s *Sock5ModeServer) handleUDP(c net.Conn) { + log.Println("UDP Associate") + /* + +----+------+------+----------+----------+----------+ + |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | + +----+------+------+----------+----------+----------+ + | 2 | 1 | 1 | Variable | 2 | Variable | + +----+------+------+----------+----------+----------+ + */ + buf := make([]byte, 3) + c.Read(buf) + // relay udp datagram silently, without any notification to the requesting client + if buf[2] != 0 { + // does not support fragmentation, drop it + log.Println("does not support fragmentation, drop") + dummy := make([]byte, maxUDPPacketSize) + c.Read(dummy) + } + + proxyConn, err := s.doConnect(c, associateMethod) + if err != nil { + c.Close() + } else { + go io.Copy(c, proxyConn) + go io.Copy(proxyConn, c) + } +} + +func (s *Sock5ModeServer) handleNewConn(c net.Conn) { + buf := make([]byte, 2) + if _, err := io.ReadFull(c, buf); err != nil { + log.Println("negotiation err", err) + c.Close() + return + } + + if version := buf[0]; version != 5 { + log.Println("only support socks5, request from: ", c.RemoteAddr()) + c.Close() + return + } + nMethods := buf[1] + + methods := make([]byte, nMethods) + if len, err := c.Read(methods); len != int(nMethods) || err != nil { + log.Println("wrong method") + c.Close() + return + } + // no authentication required for now + buf[1] = 0 + // send a METHOD selection message + c.Write(buf) + + s.handleRequest(c) +} + +func (s *Sock5ModeServer) Start() { + l, err := net.Listen("tcp", ":"+strconv.Itoa(s.httpPort)) + if err != nil { + log.Fatal("listen error: ", err) + } + s.StartTunnel() + for { + conn, err := l.Accept() + if err != nil { + log.Fatal("accept error: ", err) + } + go s.handleNewConn(conn) + } +} + +func NewSock5ModeServer(tcpPort, httpPort int) *Sock5ModeServer { + s := new(Sock5ModeServer) + s.tunnelPort = tcpPort + s.httpPort = httpPort + s.tunnelList = make(chan *Conn, 1000) + s.signalList = make(chan *Conn, 10) + return s +} diff --git a/tunnel.go b/tunnel.go new file mode 100644 index 0000000..1d8342f --- /dev/null +++ b/tunnel.go @@ -0,0 +1,97 @@ +package main + +import ( + "bytes" + "errors" + "fmt" + "log" + "net" + "sync" + "time" +) + +type Tunnel struct { + tunnelPort int //通信隧道端口 + listener *net.TCPListener //server端监听 + signalList chan *Conn //通信 + tunnelList chan *Conn //隧道 + sync.RWMutex +} + +func (s *Tunnel) StartTunnel() error { + var err error + s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.tunnelPort, ""}) + if err != nil { + return err + } + go s.tunnelProcess() + return nil +} + +//tcp server +func (s *Tunnel) tunnelProcess() error { + var err error + for { + conn, err := s.listener.Accept() + if err != nil { + log.Println(err) + continue + } + go s.cliProcess(NewConn(conn)) + } + return err +} + +//验证失败,返回错误验证flag,并且关闭连接 +func (s *Tunnel) verifyError(c *Conn) { + c.conn.Write([]byte(VERIFY_EER)) + c.conn.Close() +} + +func (s *Tunnel) cliProcess(c *Conn) error { + c.conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second)) + vval := make([]byte, 20) + _, err := c.conn.Read(vval) + if err != nil { + log.Println("客户端读超时。客户端地址为::", c.conn.RemoteAddr()) + c.conn.Close() + return err + } + if bytes.Compare(vval, getverifyval()[:]) != 0 { + log.Println("当前客户端连接校验错误,关闭此客户端:", c.conn.RemoteAddr()) + s.verifyError(c) + return err + } + //做一个判断 添加到对应的channel里面以供使用 + flag, err := c.ReadFlag() + if err != nil { + return err + } + return s.typeDeal(flag, c) +} + +//tcp连接类型区分 +func (s *Tunnel) typeDeal(typeVal string, c *Conn) error { + switch typeVal { + case WORK_MAIN: + s.signalList <- c + case WORK_CHAN: + s.tunnelList <- c + default: + return errors.New("无法识别") + } + c.SetAlive() + return nil +} + +//新建隧道 +func (s *Tunnel) newChan() { +retry: + connPass := <-s.signalList + _, err := connPass.conn.Write([]byte("chan")) + if err != nil { + fmt.Println(err) + goto retry + } + s.signalList <- connPass +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..4e95278 --- /dev/null +++ b/util.go @@ -0,0 +1,148 @@ +package main + +import ( + "bufio" + "bytes" + "compress/gzip" + "encoding/binary" + "errors" + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "strconv" + "strings" +) + +var ( + disabledRedirect = errors.New("disabled redirect.") +) + + + + +func BadRequest(w http.ResponseWriter) { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) +} + + + +//发送请求并转为bytes +func GetEncodeResponse(req *http.Request) ([]byte, error) { + var respBytes []byte + client := new(http.Client) + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return disabledRedirect + } + resp, err := client.Do(req) + disRedirect := err != nil && strings.Contains(err.Error(), disabledRedirect.Error()) + if err != nil && !disRedirect { + return respBytes, err + } + if !disRedirect { + defer resp.Body.Close() + } else { + resp.Body = nil + resp.ContentLength = 0 + } + respBytes, err = EncodeResponse(resp) + return respBytes, nil +} + + +// 将request 的处理 +func EncodeRequest(r *http.Request) ([]byte, error) { + raw := bytes.NewBuffer([]byte{}) + // 写签名 + binary.Write(raw, binary.LittleEndian, []byte("sign")) + reqBytes, err := httputil.DumpRequest(r, true) + if err != nil { + return nil, err + } + // 写body数据长度 + 1 + binary.Write(raw, binary.LittleEndian, int32(len(reqBytes)+1)) + // 判断是否为http或者https的标识1字节 + binary.Write(raw, binary.LittleEndian, bool(r.URL.Scheme == "https")) + if err := binary.Write(raw, binary.LittleEndian, reqBytes); err != nil { + return nil, err + } + return raw.Bytes(), nil +} + +// 将字节转为request +func DecodeRequest(data []byte) (*http.Request, error) { + if len(data) <= 100 { + return nil, errors.New("待解码的字节长度太小") + } + req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(data[1:]))) + if err != nil { + return nil, err + } + str := strings.Split(req.Host, ":") + req.Host, err = getHost(str[0]) + if err != nil { + return nil, err + } + scheme := "http" + if data[0] == 1 { + scheme = "https" + } + req.URL, _ = url.Parse(fmt.Sprintf("%s://%s%s", scheme, req.Host, req.RequestURI)) + req.RequestURI = "" + return req, nil +} + +//// 将response转为字节 +func EncodeResponse(r *http.Response) ([]byte, error) { + raw := bytes.NewBuffer([]byte{}) + binary.Write(raw, binary.LittleEndian, []byte(RES_SIGN)) + respBytes, err := httputil.DumpResponse(r, true) + if config.Replace == 1 { + respBytes = replaceHost(respBytes) + } + if err != nil { + return nil, err + } + var buf bytes.Buffer + zw := gzip.NewWriter(&buf) + zw.Write(respBytes) + zw.Close() + binary.Write(raw, binary.LittleEndian, int32(len(buf.Bytes()))) + if err := binary.Write(raw, binary.LittleEndian, buf.Bytes()); err != nil { + fmt.Println(err) + return nil, err + } + return raw.Bytes(), nil +} + +// 将字节转为response +func DecodeResponse(data []byte) (*http.Response, error) { + zr, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, err + } + defer zr.Close() + resp, err := http.ReadResponse(bufio.NewReader(zr), nil) + if err != nil { + return nil, err + } + return resp, nil +} + +func getHost(str string) (string, error) { + for _, v := range config.SiteList { + if v.Host == str { + return v.Url + ":" + strconv.Itoa(v.Port), nil + } + } + return "", errors.New("没有找到解析的的host!") +} + +func replaceHost(resp []byte) []byte { + str := string(resp) + for _, v := range config.SiteList { + str = strings.Replace(str, v.Url+":"+strconv.Itoa(v.Port), v.Host, -1) + str = strings.Replace(str, v.Url, v.Host, -1) + } + return []byte(str) +}