mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-02 11:56:53 +00:00
Code optimization
This commit is contained in:
@@ -28,8 +28,9 @@ type httpServer struct {
|
||||
httpsPort int //https监听端口
|
||||
pemPath string
|
||||
keyPath string
|
||||
stop chan bool
|
||||
httpslistener net.Listener
|
||||
httpServer *http.Server
|
||||
httpsServer *http.Server
|
||||
httpsListener net.Listener
|
||||
}
|
||||
|
||||
func NewHttp(bridge *bridge.Bridge, c *file.Tunnel) *httpServer {
|
||||
@@ -47,7 +48,6 @@ func NewHttp(bridge *bridge.Bridge, c *file.Tunnel) *httpServer {
|
||||
httpsPort: httpsPort,
|
||||
pemPath: pemPath,
|
||||
keyPath: keyPath,
|
||||
stop: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,13 +58,17 @@ func (s *httpServer) processHttps(c net.Conn) {
|
||||
return
|
||||
}
|
||||
var host *file.Host
|
||||
file.GetCsvDb().Lock()
|
||||
for _, v := range file.GetCsvDb().Hosts {
|
||||
file.GetCsvDb().Hosts.Range(func(key, value interface{}) bool {
|
||||
v := value.(*file.Host)
|
||||
if v.Scheme != "https" && v.Scheme != "all" {
|
||||
return true
|
||||
}
|
||||
if bytes.Index(buf[:n], []byte(v.Host)) >= 0 && (host == nil || len(host.Host) < len(v.Host)) {
|
||||
host = v
|
||||
return false
|
||||
}
|
||||
}
|
||||
file.GetCsvDb().Unlock()
|
||||
return true
|
||||
})
|
||||
if host == nil {
|
||||
logs.Error("new https connection can't be parsed!", c.RemoteAddr().String())
|
||||
c.Close()
|
||||
@@ -76,40 +80,37 @@ func (s *httpServer) processHttps(c net.Conn) {
|
||||
r.URL = new(url.URL)
|
||||
r.URL.Scheme = "https"
|
||||
r.Host = host.Host
|
||||
//read the host form connection
|
||||
if !host.Client.GetConn() { //conn num limit
|
||||
logs.Notice("connections exceed the current client %d limit %d ,now connection num %d", host.Client.Id, host.Client.MaxConn, host.Client.NowConn)
|
||||
if err := s.CheckFlowAndConnNum(host.Client); err != nil {
|
||||
logs.Warn("client id %d, host id %d, error %s, when https connection", host.Client.Id, host.Id, err.Error())
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
//流量限制
|
||||
if host.Client.Flow.FlowLimit > 0 && (host.Client.Flow.FlowLimit<<20) < (host.Client.Flow.ExportFlow+host.Client.Flow.InletFlow) {
|
||||
logs.Warn("Traffic exceeded client id %s", host.Client.Id)
|
||||
if err = s.auth(r, conn.NewConn(c), host.Client.Cnf.U, host.Client.Cnf.P); err != nil {
|
||||
logs.Warn("auth error", err, r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
if targetAddr, err = host.GetRandomTarget(); err != nil {
|
||||
logs.Warn(err.Error())
|
||||
}
|
||||
logs.Trace("new https connection,clientId %d,host %s,remote address %s", host.Client.Id, r.Host, c.RemoteAddr().String())
|
||||
s.DealClient(conn.NewConn(c), host.Client, targetAddr, buf[:n], common.CONN_TCP)
|
||||
s.DealClient(conn.NewConn(c), host.Client, targetAddr, buf[:n], common.CONN_TCP, nil)
|
||||
}
|
||||
|
||||
func (s *httpServer) Start() error {
|
||||
var err error
|
||||
var httpSrv, httpsSrv *http.Server
|
||||
if s.errorContent, err = common.ReadAllFromFile(filepath.Join(common.GetRunPath(), "web", "static", "page", "error.html")); err != nil {
|
||||
s.errorContent = []byte("easyProxy 404")
|
||||
}
|
||||
|
||||
if s.httpPort > 0 {
|
||||
httpSrv = s.NewServer(s.httpPort, "http")
|
||||
s.httpServer = s.NewServer(s.httpPort, "http")
|
||||
go func() {
|
||||
l, err := connection.GetHttpListener()
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
os.Exit(0)
|
||||
}
|
||||
err = httpSrv.Serve(l)
|
||||
err = s.httpServer.Serve(l)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
os.Exit(0)
|
||||
@@ -117,23 +118,16 @@ func (s *httpServer) Start() error {
|
||||
}()
|
||||
}
|
||||
if s.httpsPort > 0 {
|
||||
if !common.FileExists(s.pemPath) {
|
||||
os.Exit(0)
|
||||
}
|
||||
if !common.FileExists(s.keyPath) {
|
||||
logs.Error("ssl keyFile %s exist", s.keyPath)
|
||||
os.Exit(0)
|
||||
}
|
||||
httpsSrv = s.NewServer(s.httpsPort, "https")
|
||||
s.httpsServer = s.NewServer(s.httpsPort, "https")
|
||||
go func() {
|
||||
l, err := connection.GetHttpsListener()
|
||||
s.httpsListener, err = connection.GetHttpsListener()
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
os.Exit(0)
|
||||
}
|
||||
if b, err := beego.AppConfig.Bool("https_just_proxy"); err == nil && b {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
c, err := s.httpsListener.Accept()
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
break
|
||||
@@ -141,7 +135,15 @@ func (s *httpServer) Start() error {
|
||||
go s.processHttps(c)
|
||||
}
|
||||
} else {
|
||||
err = httpsSrv.ServeTLS(l, s.pemPath, s.keyPath)
|
||||
if !common.FileExists(s.pemPath) {
|
||||
logs.Error("ssl certFile %s exist", s.keyPath)
|
||||
os.Exit(0)
|
||||
}
|
||||
if !common.FileExists(s.keyPath) {
|
||||
logs.Error("ssl keyFile %s exist", s.keyPath)
|
||||
os.Exit(0)
|
||||
}
|
||||
err = s.httpsServer.ServeTLS(s.httpsListener, s.pemPath, s.keyPath)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
os.Exit(0)
|
||||
@@ -149,20 +151,19 @@ func (s *httpServer) Start() error {
|
||||
}
|
||||
}()
|
||||
}
|
||||
select {
|
||||
case <-s.stop:
|
||||
if httpSrv != nil {
|
||||
httpsSrv.Close()
|
||||
}
|
||||
if httpsSrv != nil {
|
||||
httpsSrv.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *httpServer) Close() error {
|
||||
s.stop <- true
|
||||
if s.httpsListener != nil {
|
||||
s.httpsListener.Close()
|
||||
}
|
||||
if s.httpsServer != nil {
|
||||
s.httpsServer.Close()
|
||||
}
|
||||
if s.httpServer != nil {
|
||||
s.httpServer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -196,23 +197,17 @@ func (s *httpServer) process(c *conn.Conn, r *http.Request) {
|
||||
if host, err = file.GetCsvDb().GetInfoByHost(r.Host, r); err != nil {
|
||||
logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI)
|
||||
goto end
|
||||
} else if !host.Client.GetConn() { //conn num limit
|
||||
logs.Notice("connections exceed the current client %d limit %d ,now connection num %d", host.Client.Id, host.Client.MaxConn, host.Client.NowConn)
|
||||
}
|
||||
if err := s.CheckFlowAndConnNum(host.Client); err != nil {
|
||||
logs.Warn("client id %d, host id %d, error %s, when https connection", host.Client.Id, host.Id, err.Error())
|
||||
c.Close()
|
||||
return
|
||||
} else {
|
||||
logs.Trace("new %s connection,clientId %d,host %s,url %s,remote address %s", r.URL.Scheme, host.Client.Id, r.Host, r.URL, r.RemoteAddr)
|
||||
lastHost = host
|
||||
}
|
||||
logs.Trace("new %s connection,clientId %d,host %s,url %s,remote address %s", r.URL.Scheme, host.Client.Id, r.Host, r.URL, r.RemoteAddr)
|
||||
lastHost = host
|
||||
for {
|
||||
start:
|
||||
if isConn {
|
||||
//流量限制
|
||||
if host.Client.Flow.FlowLimit > 0 && (host.Client.Flow.FlowLimit<<20) < (host.Client.Flow.ExportFlow+host.Client.Flow.InletFlow) {
|
||||
logs.Warn("Traffic exceeded client id %s", host.Client.Id)
|
||||
break
|
||||
}
|
||||
//权限控制
|
||||
if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil {
|
||||
logs.Warn("auth error", err, r.RemoteAddr)
|
||||
break
|
||||
|
Reference in New Issue
Block a user