nps/bridge/bridge.go

407 lines
9.9 KiB
Go
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package bridge
import (
"encoding/binary"
"errors"
"fmt"
"github.com/cnlh/nps/lib/common"
"github.com/cnlh/nps/lib/conn"
"github.com/cnlh/nps/lib/crypt"
"github.com/cnlh/nps/lib/file"
"github.com/cnlh/nps/lib/mux"
"github.com/cnlh/nps/lib/version"
"github.com/cnlh/nps/server/tool"
"github.com/cnlh/nps/vender/github.com/astaxie/beego"
"github.com/cnlh/nps/vender/github.com/astaxie/beego/logs"
"github.com/cnlh/nps/vender/github.com/xtaci/kcp"
"net"
"strconv"
"sync"
"time"
)
// Client is the bridge's state for one connected client: the multiplexed data
// tunnel (set by a WORK_CHAN connection) and the signal/control connection
// (set by a WORK_MAIN connection). The embedded RWMutex is taken when either
// field is reassigned.
type Client struct {
	tunnel *mux.Mux
	signal *conn.Conn
	sync.RWMutex
}

// NewClient returns a Client holding the given tunnel and signal connection.
// Either argument may be nil when that side has not connected yet.
func NewClient(t *mux.Mux, s *conn.Conn) *Client {
	cl := new(Client)
	cl.signal = s
	cl.tunnel = t
	return cl
}
// Bridge accepts client connections on TunnelPort and keeps track of the
// connected clients, their registered ips and the channels used to hand work
// to the rest of the server.
type Bridge struct {
	TunnelPort   int                 // port the bridge listens on for client connections
	tcpListener  *net.TCPListener    // server-side listener (tcp mode)
	kcpListener  *kcp.Listener       // server-side listener (kcp mode)
	Client       map[int]*Client     // connected clients keyed by client id; guarded by clientLock
	tunnelType   string              // bridge type: "kcp" or tcp
	OpenTask     chan *file.Tunnel   // tunnels added through getConfig are sent here to be opened
	CloseClient  chan int            // ids of clients that should be shut down
	SecretChan   chan *conn.Secret   // WORK_SECRET connections are handed over here
	clientLock   sync.RWMutex        // guards Client
	Register     map[string]time.Time // registered ip -> expiry time; guarded by registerLock
	registerLock sync.RWMutex        // guards Register
	ipVerify     bool                // when true, SendLinkInfo only serves registered, unexpired ips
	runList      map[int]interface{} // task ids consulted by WORK_STATUS; presumably the running tasks — confirm with caller
}

// NewTunnel builds a Bridge listening on tunnelPort with the given transport
// type ("kcp" or tcp). ipVerify enables ip registration checks; runList is
// kept by reference and consulted when a client asks for its status.
// All maps and (unbuffered) channels are created here.
func NewTunnel(tunnelPort int, tunnelType string, ipVerify bool, runList map[int]interface{}) *Bridge {
	return &Bridge{
		TunnelPort:  tunnelPort,
		Client:      make(map[int]*Client),
		tunnelType:  tunnelType,
		OpenTask:    make(chan *file.Tunnel),
		CloseClient: make(chan int),
		Register:    make(map[string]time.Time),
		ipVerify:    ipVerify,
		runList:     runList,
		SecretChan:  make(chan *conn.Secret),
	}
}
// StartTunnel opens the bridge listener on TunnelPort — KCP when tunnelType is
// "kcp", TCP on all interfaces otherwise — and starts a goroutine that hands
// every accepted connection to cliProcess in its own goroutine.
// It returns an error only if the listener cannot be created; accept errors
// are logged and the accept loop continues.
// NOTE(review): if the listener is ever closed, Accept keeps failing and the
// loop spins logging warnings; consider exiting on a permanent error.
func (s *Bridge) StartTunnel() error {
	var err error
	if s.tunnelType == "kcp" {
		s.kcpListener, err = kcp.ListenWithOptions(":"+strconv.Itoa(s.TunnelPort), nil, 150, 3)
		if err != nil {
			return err
		}
		go func() {
			for {
				c, err := s.kcpListener.AcceptKCP()
				// Check the error before touching c: on failure c is nil and
				// configuring the session would operate on a nil session.
				if err != nil {
					logs.Warn(err)
					continue
				}
				conn.SetUdpSession(c)
				go s.cliProcess(conn.NewConn(c))
			}
		}()
		return nil
	}

	s.tcpListener, err = net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: s.TunnelPort})
	if err != nil {
		return err
	}
	go func() {
		for {
			c, err := s.tcpListener.Accept()
			if err != nil {
				logs.Warn(err)
				continue
			}
			go s.cliProcess(conn.NewConn(c))
		}
	}()
	return nil
}
// verifyError sends the verification-failed flag (common.VERIFY_EER) to the
// peer and then closes the underlying network connection.
func (s *Bridge) verifyError(c *conn.Conn) {
	reply := []byte(common.VERIFY_EER)
	c.Write(reply)
	c.Conn.Close()
}

// verifySuccess sends the verification-succeeded flag (common.VERIFY_SUCCESS)
// to the peer; the connection is left open.
func (s *Bridge) verifySuccess(c *conn.Conn) {
	reply := []byte(common.VERIFY_SUCCESS)
	c.Write(reply)
}
// cliProcess runs the handshake for a newly accepted bridge connection:
//  1. the peer sends 32 bytes that must equal crypt.Md5 of our version, and we
//     echo our own value back;
//  2. the peer sends a 32-byte verify key, resolved to a client id by
//     GetIdByVerifyKey (failure is reported with verifyError);
//  3. the peer sends a flag that selects the connection's role, dispatched to
//     typeDeal.
//
// The connection is closed on every failure path.
func (s *Bridge) cliProcess(c *conn.Conn) {
	// Arm the read deadline before the first read so a peer that connects and
	// never sends anything cannot pin this goroutine and socket indefinitely.
	c.SetReadDeadline(5, s.tunnelType)

	// version check
	if b, err := c.GetShortContent(32); err != nil || string(b) != crypt.Md5(version.GetVersion()) {
		logs.Info("The client %s version does not match", c.Conn.RemoteAddr())
		c.Close()
		return
	}
	c.Write([]byte(crypt.Md5(version.GetVersion())))

	buf, err := c.GetShortContent(32)
	if err != nil {
		c.Close()
		return
	}

	// verification
	id, err := file.GetCsvDb().GetIdByVerifyKey(string(buf), c.Conn.RemoteAddr().String())
	if err != nil {
		logs.Info("Current client connection validation error, close this client:", c.Conn.RemoteAddr())
		s.verifyError(c) // also closes the connection
		return
	}
	s.verifySuccess(c)

	// Read the role flag and hand the connection to the matching handler.
	flag, err := c.ReadFlag()
	if err != nil {
		logs.Warn(err, flag)
		c.Close() // nobody else owns c at this point; don't leak it
		return
	}
	s.typeDeal(flag, c, id)
}
// DelClient removes the client with the given id from the bridge and closes its
// signal connection, if it has one. If the stored client record is marked
// NoStore, its id is also sent on CloseClient. Unknown ids are ignored.
// isOther is currently unused.
func (s *Bridge) DelClient(id int, isOther bool) {
	s.clientLock.Lock()
	v, ok := s.Client[id]
	if ok {
		delete(s.Client, id)
	}
	s.clientLock.Unlock()
	if !ok {
		return
	}

	// Notify outside clientLock: CloseClient is unbuffered, so this send blocks
	// until someone receives it. Holding the lock meanwhile would stall every
	// other user of s.Client and could deadlock if the receiver calls back
	// into the bridge.
	if c, err := file.GetCsvDb().GetClient(id); err == nil && c.NoStore {
		s.CloseClient <- c.Id
	}

	v.RLock()
	signal := v.signal
	v.RUnlock()
	// A client that has only opened its data channel has no signal connection.
	if signal != nil {
		signal.Close()
	}
}
// typeDeal dispatches an authenticated bridge connection according to the role
// flag sent by the client (typeVal); id is the client id resolved in cliProcess.
//
//   - WORK_MAIN: control ("signal") connection; replaces any previous one and
//     removes the client once this connection ends.
//   - WORK_CHAN: data connection, wrapped in a mux and stored as the tunnel.
//   - WORK_CONFIG: config upload, served by getConfig.
//   - WORK_REGISTER: ip registration, served by register.
//   - WORK_SECRET: secret-mode connection, handed over on SecretChan.
//   - WORK_P2P: p2p request; the client owning the matching task is sent the
//     server's p2p address and the secret, and the requester gets the address.
//
// Connections that are kept are marked alive with SetAlive; connections
// rejected on the WORK_SECRET and WORK_P2P paths are closed.
func (s *Bridge) typeDeal(typeVal string, c *conn.Conn, id int) {
	switch typeVal {
	case common.WORK_MAIN:
		// The vkey connected again: install the new signal connection first,
		// then tell the previous one (if any) to close.
		s.clientLock.Lock()
		if v, ok := s.Client[id]; ok {
			s.clientLock.Unlock()
			v.Lock()
			old := v.signal
			v.signal = c
			v.Unlock()
			if old != nil {
				old.WriteClose()
			}
		} else {
			s.Client[id] = NewClient(nil, c)
			s.clientLock.Unlock()
		}
		go func(id int) {
			// NOTE(review): `true` is passed by value, so binary.Read cannot
			// store into it; the call only serves to block until the peer
			// sends a byte or the connection fails.
			binary.Read(c, binary.LittleEndian, true)
			// Tear the client down only if c is still its signal connection.
			// If a newer WORK_MAIN connection replaced c, deleting here would
			// close the new connection instead of this stale one.
			s.clientLock.RLock()
			cur, ok := s.Client[id]
			s.clientLock.RUnlock()
			if ok {
				cur.RLock()
				stale := cur.signal != c
				cur.RUnlock()
				if stale {
					c.Close()
					return
				}
			}
			s.DelClient(id, false)
		}(id)
		logs.Info("clientId %d connection succeeded, address:%s ", id, c.Conn.RemoteAddr())
	case common.WORK_CHAN:
		s.clientLock.Lock()
		if v, ok := s.Client[id]; ok {
			s.clientLock.Unlock()
			v.Lock()
			v.tunnel = mux.NewMux(c.Conn)
			v.Unlock()
		} else {
			s.Client[id] = NewClient(mux.NewMux(c.Conn), nil)
			s.clientLock.Unlock()
		}
	case common.WORK_CONFIG:
		go s.getConfig(c)
	case common.WORK_REGISTER:
		go s.register(c)
	case common.WORK_SECRET:
		b, err := c.GetShortContent(32)
		if err != nil {
			c.Close()
			return
		}
		s.SecretChan <- conn.NewSecret(string(b), c)
	case common.WORK_P2P:
		// read md5 secret
		b, err := c.GetShortContent(32)
		if err != nil {
			c.Close()
			return
		}
		t := file.GetCsvDb().GetTaskByMd5Password(string(b))
		if t == nil {
			c.Close()
			return
		}
		s.clientLock.RLock()
		v, ok := s.Client[t.Client.Id]
		s.clientLock.RUnlock()
		if !ok {
			c.Close()
			return
		}
		v.RLock()
		signal := v.signal
		v.RUnlock()
		// The owning client may have connected only its data channel so far.
		if signal == nil {
			c.Close()
			return
		}
		// Tell the client that owns the secret to open a udp connection to the
		// server, sending it the server address and the secret.
		signal.Write([]byte(common.NEW_UDP_CONN))
		svrAddr := beego.AppConfig.String("serverIp") + ":" + beego.AppConfig.String("p2pPort")
		signal.WriteLenContent([]byte(svrAddr))
		signal.WriteLenContent(b)
		// Send the same server address back to the requester.
		c.WriteLenContent([]byte(svrAddr))
	}
	c.SetAlive(s.tunnelType)
}
// register handles a WORK_REGISTER connection. It reads a little-endian int32
// number of hours and records the peer's ip in s.Register with an expiry of
// now + hours. SendLinkInfo checks that expiry when ipVerify is enabled.
// If the read fails nothing is recorded. The connection is closed when done,
// because nothing else uses it after registration.
func (s *Bridge) register(c *conn.Conn) {
	defer c.Close()
	var hour int32
	if err := binary.Read(c, binary.LittleEndian, &hour); err != nil {
		return
	}
	ip := common.GetIpByAddr(c.Conn.RemoteAddr().String())
	expiry := time.Now().Add(time.Hour * time.Duration(hour))
	s.registerLock.Lock()
	s.Register[ip] = expiry
	s.registerLock.Unlock()
}
// SendLinkInfo opens a new stream on the data tunnel of client clientId, sends
// link over it and returns the stream as target.
//
// When ipVerify is enabled, the ip of linkAddr must be present in s.Register
// with an expiry in the future. An error is returned if the client is not
// connected, has no data tunnel, the ip check fails, or opening/sending on the
// stream fails. On error target is nil; a stream that was already opened is
// closed.
func (s *Bridge) SendLinkInfo(clientId int, link *conn.Link, linkAddr string) (target net.Conn, err error) {
	s.clientLock.RLock()
	v, ok := s.Client[clientId]
	s.clientLock.RUnlock()
	if !ok {
		return nil, fmt.Errorf("the client %d is not connect", clientId)
	}

	if s.ipVerify {
		ip := common.GetIpByAddr(linkAddr)
		// Copy the entry out and release the lock before any check, so that no
		// return path can leave registerLock held.
		s.registerLock.RLock()
		expiry, registered := s.Register[ip]
		s.registerLock.RUnlock()
		if !registered {
			return nil, fmt.Errorf("The ip %s is not in the validation list", ip)
		}
		if !expiry.After(time.Now()) {
			return nil, fmt.Errorf("The validity of the ip %s has expired", ip)
		}
	}

	v.RLock()
	tunnel := v.tunnel
	v.RUnlock()
	if tunnel == nil {
		return nil, errors.New("the client connect error")
	}
	if target, err = tunnel.NewConn(); err != nil {
		return nil, err
	}
	if _, err = conn.NewConn(target).SendLinkInfo(link); err != nil {
		logs.Warn("new connect error ,the target %s refuse to connect", link.Host)
		target.Close() // don't leak the stream we just opened
		return nil, err
	}
	return target, nil
}
// getConfig serves a WORK_CONFIG connection. It reads flags until ReadFlag
// fails and handles each one:
//   - WORK_STATUS: reads a 32-byte vkey and replies with the remarks of that
//     client's hosts and of its tasks listed in runList, each followed by
//     common.CONN_DATA_SEQ (length-prefixed, little-endian int32).
//   - NEW_CONF: creates a client from the uploaded config and replies with
//     its verify key.
//   - NEW_HOST: adds a host owned by the client created by NEW_CONF.
//   - NEW_TASK: adds one tunnel per port of the uploaded task and sends each
//     on OpenTask.
//
// NEW_HOST and NEW_TASK are rejected if no client has been set by NEW_CONF.
// If any NEW_* step failed and a client was set, its id is sent on
// CloseClient. The connection is always closed on return.
func (s *Bridge) getConfig(c *conn.Conn) {
	var client *file.Client
	var fail bool
	for {
		flag, err := c.ReadFlag()
		if err != nil {
			break
		}
		switch flag {
		case common.WORK_STATUS:
			b, err := c.GetShortContent(32)
			if err != nil {
				break
			}
			id, err := file.GetCsvDb().GetClientIdByVkey(string(b))
			if err != nil {
				break
			}
			var str string
			for _, v := range file.GetCsvDb().Hosts {
				if v.Client.Id == id {
					str += v.Remark + common.CONN_DATA_SEQ
				}
			}
			for _, v := range file.GetCsvDb().Tasks {
				if _, ok := s.runList[v.Id]; ok && v.Client.Id == id {
					str += v.Remark + common.CONN_DATA_SEQ
				}
			}
			binary.Write(c, binary.LittleEndian, int32(len([]byte(str))))
			binary.Write(c, binary.LittleEndian, []byte(str))
		case common.NEW_CONF:
			var err error
			if client, err = c.GetConfigInfo(); err != nil {
				fail = true
				c.WriteAddFail()
				break
			}
			if err = file.GetCsvDb().NewClient(client); err != nil {
				fail = true
				c.WriteAddFail()
				break
			}
			c.WriteAddOk()
			c.Write([]byte(client.VerifyKey))
		case common.NEW_HOST:
			h, err := c.GetHostInfo()
			// A host needs an owning client; without one it would be stored
			// with a nil Client.
			if err != nil || client == nil || file.GetCsvDb().IsHostExist(h) {
				fail = true
				c.WriteAddFail()
				break
			}
			h.Client = client
			file.GetCsvDb().NewHost(h)
			c.WriteAddOk()
		case common.NEW_TASK:
			t, err := c.GetTaskInfo()
			if err != nil || client == nil {
				fail = true
				c.WriteAddFail()
				break
			}
			ports := common.GetPorts(t.Ports)
			targets := common.GetPorts(t.Target)
			if len(ports) > 1 && (t.Mode == "tcp" || t.Mode == "udp") && (len(ports) != len(targets)) {
				fail = true
				c.WriteAddFail()
				break
			}
			if t.Mode == "secret" {
				ports = append(ports, 0)
			}
			// With more than one port, ports are paired with targets by index
			// below; refuse the task instead of indexing past the end of targets.
			if len(ports) == 0 || (len(ports) > 1 && len(targets) < len(ports)) {
				fail = true
				c.WriteAddFail()
				break
			}
			for i := 0; i < len(ports); i++ {
				tl := new(file.Tunnel)
				tl.Mode = t.Mode
				tl.Port = ports[i]
				if len(ports) == 1 {
					tl.Target = t.Target
					tl.Remark = t.Remark
				} else {
					tl.Remark = t.Remark + "_" + strconv.Itoa(tl.Port)
					if t.TargetAddr != "" {
						tl.Target = t.TargetAddr + ":" + strconv.Itoa(targets[i])
					} else {
						tl.Target = strconv.Itoa(targets[i])
					}
				}
				tl.Id = file.GetCsvDb().GetTaskId()
				tl.Status = true
				tl.Flow = new(file.Flow)
				tl.NoStore = true
				tl.Client = client
				tl.Password = t.Password
				if err := file.GetCsvDb().NewTask(tl); err != nil {
					logs.Notice("Add task error ", err.Error())
					fail = true
					c.WriteAddFail()
					break // stop adding the remaining ports
				}
				if b := tool.TestServerPort(tl.Port, tl.Mode); !b && t.Mode != "secret" {
					fail = true
					c.WriteAddFail()
					break // stop adding the remaining ports
				}
				s.OpenTask <- tl
				c.WriteAddOk()
			}
		}
	}
	if fail && client != nil {
		s.CloseClient <- client.Id
	}
	c.Close()
}