diff --git a/core/config.go b/core/config.go index 64a02c1..5488aa0 100644 --- a/core/config.go +++ b/core/config.go @@ -4,21 +4,22 @@ package core type Config struct { ConfigName string Description string + ConfigLevel ConfigLevel } type NpsConfigs struct { configs []*Config } -func NewNpsConfigs(name, des string) *NpsConfigs { +func NewNpsConfigs(name, des string, level ConfigLevel) *NpsConfigs { c := &NpsConfigs{} c.configs = make([]*Config, 0) - c.Add(name, des) + c.Add(name, des, level) return c } -func (config *NpsConfigs) Add(name, des string) { - config.configs = append(config.configs, &Config{ConfigName: name, Description: des}) +func (config *NpsConfigs) Add(name, des string, level ConfigLevel) { + config.configs = append(config.configs, &Config{ConfigName: name, Description: des, ConfigLevel: level}) } func (config *NpsConfigs) GetAll() []*Config { diff --git a/core/plugin.go b/core/plugin.go index a3428e1..bb222e3 100644 --- a/core/plugin.go +++ b/core/plugin.go @@ -9,40 +9,51 @@ import ( // Plugin interface, all plugins must implement those functions. type Plugin interface { GetConfigName() *NpsConfigs - GetConfigLevel() ConfigLevel - GetStage() Stage - Start(ctx context.Context, config map[string]string) (context.Context, error) - Run(ctx context.Context, config map[string]string) (context.Context, error) - End(ctx context.Context, config map[string]string) (context.Context, error) + InitConfig(globalConfig, clientConfig, pluginConfig map[string]string) + GetStage() []Stage + Start(ctx context.Context) (context.Context, error) + Run(ctx context.Context) (context.Context, error) + End(ctx context.Context) (context.Context, error) } type NpsPlugin struct { Version string + Configs map[string]string } func (npsPlugin *NpsPlugin) GetConfigName() *NpsConfigs { return nil } -// describe the config level -func (npsPlugin *NpsPlugin) GetConfigLevel() ConfigLevel { - return CONFIG_LEVEL_PLUGIN +func (npsPlugin *NpsPlugin) InitConfig(globalConfig, clientConfig, pluginConfig map[string]string) { + npsPlugin.Configs = make(map[string]string) + for _, cfg := range npsPlugin.GetConfigName().GetAll() { + switch cfg.ConfigLevel { + case CONFIG_LEVEL_PLUGIN: + npsPlugin.Configs[cfg.ConfigName] = pluginConfig[cfg.ConfigName] + case CONFIG_LEVEL_CLIENT: + npsPlugin.Configs[cfg.ConfigName] = clientConfig[cfg.ConfigName] + case CONFIG_LEVEL_GLOBAL: + npsPlugin.Configs[cfg.ConfigName] = globalConfig[cfg.ConfigName] + } + } + return } // describe the stage of the plugin -func (npsPlugin *NpsPlugin) GetStage() Stage { - return STAGE_RUN +func (npsPlugin *NpsPlugin) GetStage() []Stage { + return []Stage{STAGE_RUN} } -func (npsPlugin *NpsPlugin) Start(ctx context.Context, config map[string]string) (context.Context, error) { +func (npsPlugin *NpsPlugin) Start(ctx context.Context) (context.Context, error) { return ctx, nil } -func (npsPlugin *NpsPlugin) Run(ctx context.Context, config map[string]string) (context.Context, error) { +func (npsPlugin *NpsPlugin) Run(ctx context.Context) (context.Context, error) { return ctx, nil } -func (npsPlugin *NpsPlugin) End(ctx context.Context, config map[string]string) (context.Context, error) { +func (npsPlugin *NpsPlugin) End(ctx context.Context) (context.Context, error) { return ctx, nil } @@ -50,6 +61,10 @@ func (npsPlugin *NpsPlugin) GetClientConn(ctx context.Context) net.Conn { return ctx.Value(CLIENT_CONNECTION).(net.Conn) } +func (npsPlugin *NpsPlugin) SetClientConn(ctx context.Context, conn net.Conn) context.Context { + return context.WithValue(ctx, CLIENT_CONNECTION, conn) +} + func (npsPlugin *NpsPlugin) GetBridge(ctx context.Context) *bridge.Bridge { return ctx.Value(BRIDGE).(*bridge.Bridge) } @@ -59,17 +74,44 @@ func (npsPlugin *NpsPlugin) GetClientId(ctx context.Context) int { } type Plugins struct { - pgs []Plugin + StartPgs []Plugin + RunPgs []Plugin + EndPgs []Plugin + AllPgs []Plugin } func NewPlugins() *Plugins { p := &Plugins{} - p.pgs = make([]Plugin, 0) + p.StartPgs = make([]Plugin, 0) + p.RunPgs = make([]Plugin, 0) + p.EndPgs = make([]Plugin, 0) + p.AllPgs = make([]Plugin, 0) return p } func (pl *Plugins) Add(plugins ...Plugin) { for _, plugin := range plugins { - pl.pgs = append(pl.pgs, plugin) + for _, v := range plugin.GetStage() { + pl.AllPgs = append(pl.RunPgs, plugin) + switch v { + case STAGE_RUN: + pl.RunPgs = append(pl.RunPgs, plugin) + case STAGE_END: + pl.EndPgs = append(pl.EndPgs, plugin) + case STAGE_START: + pl.StartPgs = append(pl.StartPgs, plugin) + } + } } } + +func RunPlugin(ctx context.Context, pgs []Plugin) error { + var err error + for _, pg := range pgs { + ctx, err = pg.Start(ctx) + if err != nil { + return err + } + } + return nil +} diff --git a/core/struct.go b/core/struct.go index 4720386..94297a2 100644 --- a/core/struct.go +++ b/core/struct.go @@ -8,11 +8,7 @@ type Stage uint8 // These constants are meant to describe the stage in which the plugin is running. const ( - STAGE_START_RUN_END Stage = iota - STAGE_START_RUN - STAGE_START_END - STAGE_RUN_END - STAGE_START + STAGE_START Stage = iota STAGE_END STAGE_RUN PROXY_CONNECTION_TYPE = "proxy_target_type" @@ -33,8 +29,7 @@ const ( var ( CLIENT_CONNECTION_NOT_EXIST = errors.New("the client connection is not exist") - BRIDGE_NOT_EXIST = errors.New("the client connection is not exist") + BRIDGE_NOT_EXIST = errors.New("the bridge is not exist") REQUEST_EOF = errors.New("the request has finished") CLIENT_ID_NOT_EXIST = errors.New("the client id is not exist") ) - diff --git a/core/utils.go b/core/utils.go index c3b6865..4d8b058 100644 --- a/core/utils.go +++ b/core/utils.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/binary" "encoding/json" + "github.com/astaxie/beego/logs" "io" "net" + "strings" ) func CopyBuffer(dst io.Writer, src io.Reader) (written int64, err error) { @@ -69,3 +71,28 @@ func GetLenBytes(buf []byte) (b []byte, err error) { b = raw.Bytes() return } + +func NewTcpListenerAndProcess(addr string, f func(c net.Conn), listener *net.Listener) error { + var err error + *listener, err = net.Listen("tcp", addr) + if err != nil { + return err + } + Accept(*listener, f) + return nil +} + +func Accept(l net.Listener, f func(c net.Conn)) { + for { + c, err := l.Accept() + if err != nil { + if strings.Contains(err.Error(), "use of closed network connection") { + break + } + logs.Warn(err) + continue + } + go f(c) + } +} + diff --git a/server/common/common_inet_proxy_handle.go b/server/common/common_inet_proxy_handle.go index acda430..77cfbbd 100644 --- a/server/common/common_inet_proxy_handle.go +++ b/server/common/common_inet_proxy_handle.go @@ -13,10 +13,10 @@ type Proxy struct { } func (proxy *Proxy) GetConfigName() *core.NpsConfigs { - return core.NewNpsConfigs("socks5_proxy", "proxy to inet") + return core.NewNpsConfigs("socks5_proxy", "proxy to inet", core.CONFIG_LEVEL_PLUGIN) } -func (proxy *Proxy) Run(ctx context.Context, config map[string]string) (context.Context, error) { +func (proxy *Proxy) Run(ctx context.Context) (context.Context, error) { proxy.ctx = ctx proxy.clientConn = proxy.GetClientConn(ctx) clientId := proxy.GetClientId(ctx) @@ -27,7 +27,13 @@ func (proxy *Proxy) Run(ctx context.Context, config map[string]string) (context. return ctx, err } + // send connection information to the npc + if _, err := core.SendInfo(severConn, nil); err != nil { + return ctx, err + } + + // data exchange go core.CopyBuffer(severConn, proxy.clientConn) core.CopyBuffer(proxy.clientConn, severConn) - return ctx, nil + return ctx, core.REQUEST_EOF } diff --git a/server/socks5/socks5_check_access_handle.go b/server/socks5/socks5_check_access_handle.go index a5c607c..bea4028 100644 --- a/server/socks5/socks5_check_access_handle.go +++ b/server/socks5/socks5_check_access_handle.go @@ -17,16 +17,16 @@ type CheckAccess struct { } func (check *CheckAccess) GetConfigName() *core.NpsConfigs { - c := core.NewNpsConfigs("socks5_simple_access_check", "need check the permission simply") - c.Add("socks5_simple_access_username", "simple auth username") - c.Add("socks5_simple_access_password", "simple auth password") + c := core.NewNpsConfigs("socks5_simple_access_check", "need check the permission simply", core.CONFIG_LEVEL_PLUGIN) + c.Add("socks5_simple_access_username", "simple auth username", core.CONFIG_LEVEL_PLUGIN) + c.Add("socks5_simple_access_password", "simple auth password", core.CONFIG_LEVEL_PLUGIN) return c } -func (check *CheckAccess) Run(ctx context.Context, config map[string]string) (context.Context, error) { +func (check *CheckAccess) Run(ctx context.Context) (context.Context, error) { check.clientConn = check.GetClientConn(ctx) - check.configUsername = config["socks5_access_username"] - check.configPassword = config["socks5_access_password"] + check.configUsername = check.Configs["socks5_access_username"] + check.configPassword = check.Configs["socks5_access_password"] return ctx, nil } diff --git a/server/socks5/socks5_handshake_handle.go b/server/socks5/socks5_handshake_handle.go index 05f841f..1e91cff 100644 --- a/server/socks5/socks5_handshake_handle.go +++ b/server/socks5/socks5_handshake_handle.go @@ -12,7 +12,7 @@ type Handshake struct { core.NpsPlugin } -func (handshake *Handshake) Run(ctx context.Context, config map[string]string) (context.Context, error) { +func (handshake *Handshake) Run(ctx context.Context) (context.Context, error) { clientConn := handshake.GetClientConn(ctx) buf := make([]byte, 2) if _, err := io.ReadFull(clientConn, buf); err != nil { diff --git a/server/socks5/socks5_read_access_handle.go b/server/socks5/socks5_read_access_handle.go index dc805b3..5cebf96 100644 --- a/server/socks5/socks5_read_access_handle.go +++ b/server/socks5/socks5_read_access_handle.go @@ -22,12 +22,12 @@ type Access struct { } func (access *Access) GetConfigName() *core.NpsConfigs { - return core.NewNpsConfigs("socks5_check_access_check", "need check the permission simply") + return core.NewNpsConfigs("socks5_check_access_check", "need check the permission simply",core.CONFIG_LEVEL_PLUGIN) } -func (access *Access) Run(ctx context.Context, config map[string]string) (context.Context, error) { +func (access *Access) Run(ctx context.Context) (context.Context, error) { access.clientConn = access.GetClientConn(ctx) - if config["socks5_check_access"] != "true" { + if access.Configs["socks5_check_access"] != "true" { return ctx, access.sendAccessMsgToClient(UserNoAuth) } // need auth diff --git a/server/socks5/socks5_read_request_handle.go b/server/socks5/socks5_read_request_handle.go index af5e295..90d5228 100644 --- a/server/socks5/socks5_read_request_handle.go +++ b/server/socks5/socks5_read_request_handle.go @@ -32,7 +32,7 @@ const ( addrTypeNotSupported = 8 ) -func (request *Request) Run(ctx context.Context, config map[string]string) (context.Context, error) { +func (request *Request) Run(ctx context.Context) (context.Context, error) { request.clientConn = request.GetClientConn(ctx) request.ctx = ctx diff --git a/server/socks5/socks5_server.go b/server/socks5/socks5_server.go index b753b38..504618f 100644 --- a/server/socks5/socks5_server.go +++ b/server/socks5/socks5_server.go @@ -1,8 +1,9 @@ package socks5 import ( + "context" + "fmt" "github.com/cnlh/nps/core" - "github.com/cnlh/nps/lib/conn" "github.com/cnlh/nps/server/common" "net" "strconv" @@ -29,8 +30,28 @@ func NewS5Server(globalConfig, clientConfig, pluginConfig map[string]string) *S5 return s5 } -func (s5 *S5Server) Start() error { - return conn.NewTcpListenerAndProcess(s5.ServerIp+":"+strconv.Itoa(s5.ServerPort), func(c net.Conn) { +func (s5 *S5Server) Start(ctx context.Context) error { + // init config of plugin + for _, pg := range s5.plugins.AllPgs { + pg.InitConfig(s5.globalConfig, s5.clientConfig, s5.pluginConfig) + } + // run the plugin contains start + if core.RunPlugin(ctx, s5.plugins.StartPgs) != nil { + return nil + } + return core.NewTcpListenerAndProcess(s5.ServerIp+":"+strconv.Itoa(s5.ServerPort), func(c net.Conn) { + // init ctx value clientConn + ctx = context.WithValue(ctx, core.CLIENT_CONNECTION, c) + // start run the plugin run + if err := core.RunPlugin(ctx, s5.plugins.RunPgs); err != nil { + fmt.Println(err) + return + } + // start run the plugin end + if err := core.RunPlugin(ctx, s5.plugins.EndPgs); err != nil { + fmt.Println(err) + return + } }, &s5.listener) }