diff --git a/server/proxy/http.go b/server/proxy/http.go index 4d6ae81..78a26b6 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -1,7 +1,7 @@ package proxy import ( - "context" + "bufio" "crypto/tls" "io" "net" @@ -10,8 +10,8 @@ import ( "os" "path/filepath" "strconv" + "strings" "sync" - "time" "ehang.io/nps/bridge" "ehang.io/nps/lib/cache" @@ -101,159 +101,174 @@ func (s *httpServer) Close() error { return nil } +func (s *httpServer) handleTunneling(w http.ResponseWriter, r *http.Request) { + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "Hijacking not supported", http.StatusInternalServerError) + return + } + c, _, err := hijacker.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + } + s.handleHttp(conn.NewConn(c), r) +} + +func (s *httpServer) handleHttp(c *conn.Conn, r *http.Request) { + var ( + host *file.Host + target net.Conn + err error + connClient io.ReadWriteCloser + scheme = r.URL.Scheme + lk *conn.Link + targetAddr string + lenConn *conn.LenConn + isReset bool + wg sync.WaitGroup + ) + defer func() { + if connClient != nil { + connClient.Close() + } else { + s.writeConnFail(c.Conn) + } + c.Close() + }() +reset: + if isReset { + host.Client.AddConn() + } + if host, err = file.GetDb().GetInfoByHost(r.Host, r); err != nil { + logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) + return + } + if err := s.CheckFlowAndConnNum(host.Client); err != nil { + logs.Warn("client id %d, host id %d, error %s, when https connection", host.Client.Id, host.Id, err.Error()) + return + } + if !isReset { + defer host.Client.AddConn() + } + if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil { + logs.Warn("auth error", err, r.RemoteAddr) + return + } + if targetAddr, err = host.Target.GetRandomTarget(); err != nil { + logs.Warn(err.Error()) + return + } + lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) + if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { + logs.Notice("connect to target %s error %s", lk.Host, err) + return + } + connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) + + //read from inc-client + go func() { + wg.Add(1) + isReset = false + defer connClient.Close() + defer func() { + wg.Done() + if !isReset { + c.Close() + } + }() + for { + if resp, err := http.ReadResponse(bufio.NewReader(connClient), r); err != nil || resp == nil { + return + } else { + //if the cache is start and the response is in the extension,store the response to the cache list + if s.useCache && r.URL != nil && strings.Contains(r.URL.Path, ".") { + b, err := httputil.DumpResponse(resp, true) + if err != nil { + return + } + c.Write(b) + host.Flow.Add(0, int64(len(b))) + s.cache.Add(filepath.Join(host.Host, r.URL.Path), b) + } else { + lenConn := conn.NewLenConn(c) + if err := resp.Write(lenConn); err != nil { + logs.Error(err) + return + } + host.Flow.Add(0, int64(lenConn.Len)) + } + } + } + }() + + for { + //if the cache start and the request is in the cache list, return the cache + if s.useCache { + if v, ok := s.cache.Get(filepath.Join(host.Host, r.URL.Path)); ok { + n, err := c.Write(v.([]byte)) + if err != nil { + break + } + logs.Trace("%s request, method %s, host %s, url %s, remote address %s, return cache", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String()) + host.Flow.Add(0, int64(n)) + //if return cache and does not create a new conn with client and Connection is not set or close, close the connection. + if strings.ToLower(r.Header.Get("Connection")) == "close" || strings.ToLower(r.Header.Get("Connection")) == "" { + break + } + goto readReq + } + } + + //change the host and header and set proxy setting + common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String(), s.addOrigin) + logs.Trace("%s request, method %s, host %s, url %s, remote address %s, target %s", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String(), lk.Host) + //write + lenConn = conn.NewLenConn(connClient) + if err := r.Write(lenConn); err != nil { + logs.Error(err) + break + } + host.Flow.Add(int64(lenConn.Len), 0) + + readReq: + //read req from connection + if r, err = http.ReadRequest(bufio.NewReader(c)); err != nil { + break + } + r.URL.Scheme = scheme + //What happened ,Why one character less??? + r.Method = resetReqMethod(r.Method) + if hostTmp, err := file.GetDb().GetInfoByHost(r.Host, r); err != nil { + logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) + break + } else if host != hostTmp { + host = hostTmp + isReset = true + connClient.Close() + goto reset + } + } + wg.Wait() +} + +func resetReqMethod(method string) string { + if method == "ET" { + return "GET" + } + if method == "OST" { + return "POST" + } + return method +} + func (s *httpServer) NewServer(port int, scheme string) *http.Server { - rProxy := NewHttpReverseProxy(s) return &http.Server{ Addr: ":" + strconv.Itoa(port), Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Scheme = scheme - rProxy.ServeHTTP(w, r) + s.handleTunneling(w, r) }), // Disable HTTP/2. TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), } } - -type HttpReverseProxy struct { - proxy *ReverseProxy - - responseHeaderTimeout time.Duration -} - -func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - var ( - host *file.Host - targetAddr string - err error - ) - if host, err = file.GetDb().GetInfoByHost(req.Host, req); err != nil { - rw.WriteHeader(http.StatusNotFound) - rw.Write([]byte(req.Host + " not found")) - return - } - if host.Client.Cnf.U != "" && host.Client.Cnf.P != "" && !common.CheckAuth(req, host.Client.Cnf.U, host.Client.Cnf.P) { - rw.WriteHeader(http.StatusUnauthorized) - rw.Write([]byte("Unauthorized")) - return - } - if targetAddr, err = host.Target.GetRandomTarget(); err != nil { - rw.WriteHeader(http.StatusBadGateway) - rw.Write([]byte("502 Bad Gateway")) - return - } - req = req.WithContext(context.WithValue(req.Context(), "host", host)) - req = req.WithContext(context.WithValue(req.Context(), "target", targetAddr)) - req = req.WithContext(context.WithValue(req.Context(), "req", req)) - - rp.proxy.ServeHTTP(rw, req) -} - -func NewHttpReverseProxy(s *httpServer) *HttpReverseProxy { - rp := &HttpReverseProxy{ - responseHeaderTimeout: 30 * time.Second, - } - local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") - proxy := NewReverseProxy(&httputil.ReverseProxy{ - Director: func(r *http.Request) { - host := r.Context().Value("host").(*file.Host) - common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, "", false) - }, - Transport: &http.Transport{ - ResponseHeaderTimeout: rp.responseHeaderTimeout, - DisableKeepAlives: true, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - var ( - host *file.Host - target net.Conn - err error - connClient io.ReadWriteCloser - targetAddr string - lk *conn.Link - ) - - r := ctx.Value("req").(*http.Request) - host = ctx.Value("host").(*file.Host) - targetAddr = ctx.Value("target").(string) - - lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) - if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { - logs.Notice("connect to target %s error %s", lk.Host, err) - return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the server") - } - connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) - return &flowConn{ - ReadWriteCloser: connClient, - fakeAddr: local, - host: host, - }, nil - }, - }, - ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { - logs.Warn("do http proxy request error: %v", err) - rw.WriteHeader(http.StatusNotFound) - }, - }) - proxy.WebSocketDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - var ( - host *file.Host - target net.Conn - err error - connClient io.ReadWriteCloser - targetAddr string - lk *conn.Link - ) - r := ctx.Value("req").(*http.Request) - host = ctx.Value("host").(*file.Host) - targetAddr = ctx.Value("target").(string) - - lk = conn.NewLink("tcp", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) - if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { - logs.Notice("connect to target %s error %s", lk.Host, err) - return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the target") - } - connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) - return &flowConn{ - ReadWriteCloser: connClient, - fakeAddr: local, - host: host, - }, nil - } - rp.proxy = proxy - return rp -} - -type flowConn struct { - io.ReadWriteCloser - fakeAddr net.Addr - host *file.Host - flowIn int64 - flowOut int64 - once sync.Once -} - -func (c *flowConn) Read(p []byte) (n int, err error) { - n, err = c.ReadWriteCloser.Read(p) - c.flowIn += int64(n) - return n, err -} - -func (c *flowConn) Write(p []byte) (n int, err error) { - n, err = c.ReadWriteCloser.Write(p) - c.flowOut += int64(n) - return n, err -} - -func (c *flowConn) Close() error { - c.once.Do(func() { c.host.Flow.Add(c.flowIn, c.flowOut) }) - return c.ReadWriteCloser.Close() -} - -func (c *flowConn) LocalAddr() net.Addr { return c.fakeAddr } - -func (c *flowConn) RemoteAddr() net.Addr { return c.fakeAddr } - -func (*flowConn) SetDeadline(t time.Time) error { return nil } - -func (*flowConn) SetReadDeadline(t time.Time) error { return nil } - -func (*flowConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/server/proxy/reverseproxy.go b/server/proxy/reverseproxy.go deleted file mode 100644 index df7e866..0000000 --- a/server/proxy/reverseproxy.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// HTTP reverse proxy handler - -package proxy - -import ( - "context" - "errors" - "io" - "net" - "net/http" - "net/http/httputil" - "net/url" - "strings" - "sync" -) - -type HTTPError struct { - error - HTTPCode int -} - -func NewHTTPError(code int, errmsg string) error { - return &HTTPError{ - error: errors.New(errmsg), - HTTPCode: code, - } -} - -type ReverseProxy struct { - *httputil.ReverseProxy - WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error) -} - -func IsWebsocketRequest(req *http.Request) bool { - containsHeader := func(name, value string) bool { - items := strings.Split(req.Header.Get(name), ",") - for _, item := range items { - if value == strings.ToLower(strings.TrimSpace(item)) { - return true - } - } - return false - } - return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket") -} - -func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { - rp := &ReverseProxy{ - ReverseProxy: httputil.NewSingleHostReverseProxy(target), - WebSocketDialContext: nil, - } - rp.ErrorHandler = rp.errHandler - return rp -} - -func NewReverseProxy(orp *httputil.ReverseProxy) *ReverseProxy { - rp := &ReverseProxy{ - ReverseProxy: orp, - WebSocketDialContext: nil, - } - rp.ErrorHandler = rp.errHandler - return rp -} - -func (p *ReverseProxy) errHandler(rw http.ResponseWriter, r *http.Request, e error) { - if e == io.EOF { - rw.WriteHeader(521) - //rw.Write(getWaitingPageContent()) - } else { - if httperr, ok := e.(*HTTPError); ok { - rw.WriteHeader(httperr.HTTPCode) - } else { - rw.WriteHeader(http.StatusNotFound) - } - rw.Write([]byte("error: " + e.Error())) - } -} - -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - if IsWebsocketRequest(req) { - p.serveWebSocket(rw, req) - } else { - p.ReverseProxy.ServeHTTP(rw, req) - } -} - -func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) { - if p.WebSocketDialContext == nil { - rw.WriteHeader(500) - return - } - targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "") - if err != nil { - rw.WriteHeader(501) - return - } - defer targetConn.Close() - - p.Director(req) - - hijacker, ok := rw.(http.Hijacker) - if !ok { - rw.WriteHeader(500) - return - } - conn, _, errHijack := hijacker.Hijack() - if errHijack != nil { - rw.WriteHeader(500) - return - } - defer conn.Close() - - req.Write(targetConn) - Join(conn, targetConn) -} - -func Join(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) (inCount int64, outCount int64) { - var wait sync.WaitGroup - pipe := func(to io.ReadWriteCloser, from io.ReadWriteCloser, count *int64) { - defer to.Close() - defer from.Close() - defer wait.Done() - - *count, _ = io.Copy(to, from) - } - - wait.Add(2) - go pipe(c1, c2, &inCount) - go pipe(c2, c1, &outCount) - wait.Wait() - return -}