nps/component/controller/controller_test.go
2022-01-23 17:30:38 +08:00

150 lines
4.7 KiB
Go

package controller
import (
"bytes"
"crypto/x509/pkix"
"ehang.io/nps/core/action"
"ehang.io/nps/core/handler"
"ehang.io/nps/core/process"
"ehang.io/nps/core/rule"
"ehang.io/nps/core/server"
"ehang.io/nps/db"
"ehang.io/nps/lib/cert"
"encoding/json"
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
func TestController(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:3500")
assert.NoError(t, err)
err = os.Remove(filepath.Join(os.TempDir(), "test_control.db"))
d := db.NewSqliteDb(filepath.Join(os.TempDir(), "test_control.db"))
err = d.Init()
assert.NoError(t, err)
assert.NoError(t, d.SetConfig("admin_user", "admin"))
assert.NoError(t, d.SetConfig("admin_pass", "pass"))
cg := cert.NewX509Generator(pkix.Name{
Country: []string{"cn"},
Organization: []string{"ehang"},
OrganizationalUnit: []string{"nps"},
Province: []string{"beijing"},
CommonName: "nps",
Locality: []string{"beijing"},
})
assert.NoError(t, err)
cert, key, err := cg.CreateRootCa()
assert.NoError(t, err)
go func() {
err = StartController(ln, d, cert, key, "./web/static/", "./web/views/")
assert.NoError(t, err)
}()
resp, err := doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/login"), "POST", `{"username": "admin","password": "pass"}`)
assert.NoError(t, err)
assert.Equal(t, int(gjson.Parse(resp).Get("code").Int()), 0)
for i := 0; i < 18; i++ {
resp, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/v1/cert"), "POST", fmt.Sprintf(`{"status":1,"name":"name_%d","cert_type": "client"}`, i))
assert.NoError(t, err)
assert.Equal(t, int(gjson.Parse(resp).Get("code").Int()), 0)
}
resp, err = doRequest(fmt.Sprintf("http://%s%s?page=%d&pageSize=%d", ln.Addr().String(), "/v1/cert/page", 4, 5), "GET", ``)
assert.NoError(t, err)
now := 2
var lastUuid string
assert.Equal(t, len(gjson.Parse(resp).Get("result.items").Array()), 3)
for _, v := range gjson.Parse(resp).Get("result.items").Array() {
assert.Equal(t, v.Get("name").String(), fmt.Sprintf(`name_%d`, now))
lastUuid = v.Get("uuid").String()
now--
}
resp, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/v1/cert"), "DELETE", fmt.Sprintf(`{"uuid":"%s"}`, lastUuid))
assert.NoError(t, err)
assert.Equal(t, int(gjson.Parse(resp).Get("code").Int()), 0)
s := &server.TcpServer{ServerAddr: "127.0.0.1:0"}
h := &handler.DefaultHandler{}
p := &process.DefaultProcess{}
a := &action.LocalAction{}
rj := &rule.JsonRule{
Name: "test",
Status: 1,
Server: rule.JsonData{ObjType: s.GetName(), ObjData: getJson(t, s)},
Handler: rule.JsonData{ObjType: h.GetName(), ObjData: getJson(t, h)},
Process: rule.JsonData{ObjType: p.GetName(), ObjData: getJson(t, p)},
Action: rule.JsonData{ObjType: a.GetName(), ObjData: getJson(t, a)},
Limiters: []rule.JsonData{},
}
js := getJson(t, rj)
for i := 0; i < 18; i++ {
resp, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/v1/rule"), "POST", js)
assert.NoError(t, err)
assert.Equal(t, int(gjson.Parse(resp).Get("code").Int()), 0)
}
resp, err = doRequest(fmt.Sprintf("http://%s%s?page=%d&pageSize=%d", ln.Addr().String(), "/v1/rule/page", 1, 10), "GET", ``)
assert.NoError(t, err)
assert.Equal(t, int(gjson.Parse(resp).Get("result.total").Int()), 18)
uuid := gjson.Parse(resp).Get("result.items").Array()[0].Get("uuid").String()
resp, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/v1/rule"), "GET", fmt.Sprintf(`{"uuid": "%s"}`, uuid))
assert.NoError(t, err)
assert.Equal(t, gjson.Parse(resp).Get("result.uuid").String(), uuid)
rj.Uuid = uuid
resp, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/v1/rule"), "PUT", getJson(t, rj))
assert.NoError(t, err)
assert.Equal(t, int(gjson.Parse(resp).Get("code").Int()), 0)
time.Sleep(time.Minute * 600)
}
func getJson(t *testing.T, i interface{}) string {
b, err := json.Marshal(i)
assert.NoError(t, err)
assert.NotEmpty(t, string(b))
return string(b)
}
var client *http.Client
var once sync.Once
var cookies []*http.Cookie
func doRequest(url string, method string, body string) (string, error) {
once.Do(func() {
client = &http.Client{}
})
payload := bytes.NewBufferString(body)
req, err := http.NewRequest(method, url, payload)
if err != nil {
return "", err
}
req.Header.Add("Content-Type", "application/json")
for _, c := range cookies {
req.AddCookie(c)
}
res, err := client.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
b, err := ioutil.ReadAll(res.Body)
if len(res.Cookies()) > 0 {
cookies = res.Cookies()
}
if res.StatusCode != 200 {
return string(b), errors.New("bad doRequest")
}
return string(b), err
}