mirror of
https://github.com/ehang-io/nps.git
synced 2025-07-02 04:00:42 +00:00
194 lines
6.0 KiB
Go
194 lines
6.0 KiB
Go
package component
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"ehang.io/nps/component/bridge"
|
|
"ehang.io/nps/component/client"
|
|
"ehang.io/nps/lib/cert"
|
|
"ehang.io/nps/lib/pb"
|
|
"github.com/lucas-clemente/quic-go"
|
|
uuid "github.com/satori/go.uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"io/ioutil"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
var createCertOnce sync.Once
|
|
|
|
func createCertFile(t *testing.T) (string, string) {
|
|
createCertOnce.Do(func() {
|
|
g := cert.NewX509Generator(pkix.Name{
|
|
Country: []string{"CN"},
|
|
Organization: []string{"Ehang.io"},
|
|
OrganizationalUnit: []string{"nps"},
|
|
Province: []string{"Beijing"},
|
|
CommonName: "nps",
|
|
Locality: []string{"Beijing"},
|
|
})
|
|
cert, key, err := g.CreateRootCa()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "root_cert.pem"), cert, 0600))
|
|
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "root_key.pem"), key, 0600))
|
|
assert.NoError(t, g.InitRootCa(cert, key))
|
|
cert, key, err = g.CreateCert("bridge.nps.ehang.io")
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "bridge_cert.pem"), cert, 0600))
|
|
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "bridge_key.pem"), key, 0600))
|
|
cert, key, err = g.CreateCert("client.nps.ehang.io")
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "client_cert.pem"), cert, 0600))
|
|
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "client_key.pem"), key, 0600))
|
|
})
|
|
return filepath.Join(os.TempDir(), "cert.pem"), filepath.Join(os.TempDir(), "key.pem")
|
|
}
|
|
|
|
func TestTcpConnect(t *testing.T) {
|
|
createCertFile(t)
|
|
bridgeLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
assert.NoError(t, err)
|
|
|
|
buf, err := ioutil.ReadFile(filepath.Join(os.TempDir(), "root_cert.pem"))
|
|
assert.NoError(t, err)
|
|
pool := x509.NewCertPool()
|
|
pool.AppendCertsFromPEM(buf)
|
|
|
|
crt, err := tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "bridge_cert.pem"), filepath.Join(os.TempDir(), "bridge_key.pem"))
|
|
assert.NoError(t, err)
|
|
|
|
bridgeTlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{crt},
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
ClientCAs: pool,
|
|
}
|
|
|
|
crt, err = tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "client_cert.pem"), filepath.Join(os.TempDir(), "client_key.pem"))
|
|
assert.NoError(t, err)
|
|
|
|
clientConfig := &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{crt},
|
|
}
|
|
go func() {
|
|
assert.NoError(t, bridge.StartTcpBridge(bridgeLn, bridgeTlsConfig, func(s string) bool {
|
|
return true
|
|
}, func(s string) bool {
|
|
sn, err := cert.GetCertSnFromConfig(clientConfig)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, sn, s)
|
|
return true
|
|
}))
|
|
}()
|
|
var c *client.Client
|
|
go func() {
|
|
id, err := cert.GetCertSnFromConfig(clientConfig)
|
|
assert.NoError(t, err)
|
|
|
|
creator := client.TcpTunnelCreator{}
|
|
connId := uuid.NewV1().String()
|
|
controlLn, err := creator.NewMux(bridgeLn.Addr().String(),
|
|
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: true}}}, clientConfig)
|
|
assert.NoError(t, err)
|
|
dataLn, err := creator.NewMux(bridgeLn.Addr().String(),
|
|
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: false}}}, clientConfig)
|
|
assert.NoError(t, err)
|
|
c = client.NewClient(controlLn, dataLn)
|
|
c.Run()
|
|
}()
|
|
timeout := time.NewTimer(time.Second * 30)
|
|
ticker := time.NewTicker(time.Millisecond * 100)
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if c != nil && c.HasPong() {
|
|
c.Close()
|
|
return
|
|
}
|
|
case <-timeout.C:
|
|
t.Fail()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestQUICConnect(t *testing.T) {
|
|
createCertFile(t)
|
|
bridgePacketConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
assert.NoError(t, err)
|
|
|
|
buf, err := ioutil.ReadFile(filepath.Join(os.TempDir(), "root_cert.pem"))
|
|
assert.NoError(t, err)
|
|
pool := x509.NewCertPool()
|
|
pool.AppendCertsFromPEM(buf)
|
|
|
|
crt, err := tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "bridge_cert.pem"), filepath.Join(os.TempDir(), "bridge_key.pem"))
|
|
assert.NoError(t, err)
|
|
|
|
bridgeTlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{crt},
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
ClientCAs: pool,
|
|
NextProtos: []string{"quic-nps"},
|
|
}
|
|
|
|
crt, err = tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "client_cert.pem"), filepath.Join(os.TempDir(), "client_key.pem"))
|
|
assert.NoError(t, err)
|
|
|
|
clientConfig := &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{crt},
|
|
NextProtos: []string{"quic-nps"},
|
|
}
|
|
go func() {
|
|
assert.NoError(t, bridge.StartQUICBridge(bridgePacketConn, bridgeTlsConfig, &quic.Config{
|
|
MaxIncomingStreams: 1000000,
|
|
MaxIncomingUniStreams: 1000000,
|
|
MaxIdleTimeout: time.Minute,
|
|
KeepAlive: true,
|
|
}, func(s string) bool {
|
|
sn, err := cert.GetCertSnFromConfig(clientConfig)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, sn, s)
|
|
return true
|
|
}))
|
|
}()
|
|
var c *client.Client
|
|
go func() {
|
|
id, err := cert.GetCertSnFromConfig(clientConfig)
|
|
assert.NoError(t, err)
|
|
|
|
creator := client.QUICTunnelCreator{}
|
|
connId := uuid.NewV1().String()
|
|
controlLn, err := creator.NewMux(bridgePacketConn.LocalAddr().String(),
|
|
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: true}}}, clientConfig)
|
|
assert.NoError(t, err)
|
|
dataLn, err := creator.NewMux(bridgePacketConn.LocalAddr().String(),
|
|
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: false}}}, clientConfig)
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, dataLn)
|
|
assert.NotEmpty(t, controlLn)
|
|
c = client.NewClient(controlLn, dataLn)
|
|
c.Run()
|
|
}()
|
|
timeout := time.NewTimer(time.Second * 30)
|
|
ticker := time.NewTicker(time.Millisecond * 100)
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if c != nil && c.HasPong() {
|
|
c.Close()
|
|
return
|
|
}
|
|
case <-timeout.C:
|
|
t.Fail()
|
|
return
|
|
}
|
|
}
|
|
}
|