mirror of
https://github.com/chai2010/advanced-go-programming-book.git
synced 2025-05-25 05:02:23 +00:00
229 lines
6.4 KiB
Go
Executable File
229 lines
6.4 KiB
Go
Executable File
package runtime_test
|
|
|
|
import (
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"context"
|
|
"github.com/golang/protobuf/proto"
|
|
pb "github.com/grpc-ecosystem/grpc-gateway/examples/proto/examplepb"
|
|
"github.com/grpc-ecosystem/grpc-gateway/runtime"
|
|
"github.com/grpc-ecosystem/grpc-gateway/runtime/internal"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
func TestForwardResponseStream(t *testing.T) {
|
|
type msg struct {
|
|
pb proto.Message
|
|
err error
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
msgs []msg
|
|
statusCode int
|
|
}{{
|
|
name: "encoding",
|
|
msgs: []msg{
|
|
{&pb.SimpleMessage{Id: "One"}, nil},
|
|
{&pb.SimpleMessage{Id: "Two"}, nil},
|
|
},
|
|
statusCode: http.StatusOK,
|
|
}, {
|
|
name: "empty",
|
|
statusCode: http.StatusOK,
|
|
}, {
|
|
name: "error",
|
|
msgs: []msg{{nil, grpc.Errorf(codes.OutOfRange, "400")}},
|
|
statusCode: http.StatusBadRequest,
|
|
}, {
|
|
name: "stream_error",
|
|
msgs: []msg{
|
|
{&pb.SimpleMessage{Id: "One"}, nil},
|
|
{nil, grpc.Errorf(codes.OutOfRange, "400")},
|
|
},
|
|
statusCode: http.StatusOK,
|
|
}}
|
|
|
|
newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
|
|
var count int
|
|
return func() (proto.Message, error) {
|
|
if count == len(msgs) {
|
|
return nil, io.EOF
|
|
} else if count > len(msgs) {
|
|
t.Errorf("recv() called %d times for %d messages", count, len(msgs))
|
|
}
|
|
count++
|
|
msg := msgs[count-1]
|
|
return msg.pb, msg.err
|
|
}
|
|
}
|
|
ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
|
|
marshaler := &runtime.JSONPb{}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
recv := newTestRecv(t, tt.msgs)
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)
|
|
|
|
w := resp.Result()
|
|
if w.StatusCode != tt.statusCode {
|
|
t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
|
|
}
|
|
if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
|
|
t.Errorf("ForwardResponseStream missing header chunked")
|
|
}
|
|
body, err := ioutil.ReadAll(w.Body)
|
|
if err != nil {
|
|
t.Errorf("Failed to read response body with %v", err)
|
|
}
|
|
w.Body.Close()
|
|
|
|
var want []byte
|
|
for i, msg := range tt.msgs {
|
|
if msg.err != nil {
|
|
if i == 0 {
|
|
// Skip non-stream errors
|
|
t.Skip("checking error encodings")
|
|
}
|
|
st, _ := status.FromError(msg.err)
|
|
httpCode := runtime.HTTPStatusFromCode(st.Code())
|
|
b, err := marshaler.Marshal(map[string]proto.Message{
|
|
"error": &internal.StreamError{
|
|
GrpcCode: int32(st.Code()),
|
|
HttpCode: int32(httpCode),
|
|
Message: st.Message(),
|
|
HttpStatus: http.StatusText(httpCode),
|
|
Details: st.Proto().GetDetails(),
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Errorf("marshaler.Marshal() failed %v", err)
|
|
}
|
|
errBytes := body[len(want):]
|
|
if string(errBytes) != string(b) {
|
|
t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", errBytes, b)
|
|
}
|
|
|
|
return
|
|
}
|
|
b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb})
|
|
if err != nil {
|
|
t.Errorf("marshaler.Marshal() failed %v", err)
|
|
}
|
|
want = append(want, b...)
|
|
want = append(want, marshaler.Delimiter()...)
|
|
}
|
|
|
|
if string(body) != string(want) {
|
|
t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// A custom marshaler implementation, that doesn't implement the delimited interface
|
|
type CustomMarshaler struct {
|
|
m *runtime.JSONPb
|
|
}
|
|
|
|
func (c *CustomMarshaler) Marshal(v interface{}) ([]byte, error) { return c.m.Marshal(v) }
|
|
func (c *CustomMarshaler) Unmarshal(data []byte, v interface{}) error { return c.m.Unmarshal(data, v) }
|
|
func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder { return c.m.NewDecoder(r) }
|
|
func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder { return c.m.NewEncoder(w) }
|
|
func (c *CustomMarshaler) ContentType() string { return c.m.ContentType() }
|
|
|
|
func TestForwardResponseStreamCustomMarshaler(t *testing.T) {
|
|
type msg struct {
|
|
pb proto.Message
|
|
err error
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
msgs []msg
|
|
statusCode int
|
|
}{{
|
|
name: "encoding",
|
|
msgs: []msg{
|
|
{&pb.SimpleMessage{Id: "One"}, nil},
|
|
{&pb.SimpleMessage{Id: "Two"}, nil},
|
|
},
|
|
statusCode: http.StatusOK,
|
|
}, {
|
|
name: "empty",
|
|
statusCode: http.StatusOK,
|
|
}, {
|
|
name: "error",
|
|
msgs: []msg{{nil, grpc.Errorf(codes.OutOfRange, "400")}},
|
|
statusCode: http.StatusBadRequest,
|
|
}, {
|
|
name: "stream_error",
|
|
msgs: []msg{
|
|
{&pb.SimpleMessage{Id: "One"}, nil},
|
|
{nil, grpc.Errorf(codes.OutOfRange, "400")},
|
|
},
|
|
statusCode: http.StatusOK,
|
|
}}
|
|
|
|
newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
|
|
var count int
|
|
return func() (proto.Message, error) {
|
|
if count == len(msgs) {
|
|
return nil, io.EOF
|
|
} else if count > len(msgs) {
|
|
t.Errorf("recv() called %d times for %d messages", count, len(msgs))
|
|
}
|
|
count++
|
|
msg := msgs[count-1]
|
|
return msg.pb, msg.err
|
|
}
|
|
}
|
|
ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
|
|
marshaler := &CustomMarshaler{&runtime.JSONPb{}}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
recv := newTestRecv(t, tt.msgs)
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)
|
|
|
|
w := resp.Result()
|
|
if w.StatusCode != tt.statusCode {
|
|
t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
|
|
}
|
|
if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
|
|
t.Errorf("ForwardResponseStream missing header chunked")
|
|
}
|
|
body, err := ioutil.ReadAll(w.Body)
|
|
if err != nil {
|
|
t.Errorf("Failed to read response body with %v", err)
|
|
}
|
|
w.Body.Close()
|
|
|
|
var want []byte
|
|
for _, msg := range tt.msgs {
|
|
if msg.err != nil {
|
|
t.Skip("checking erorr encodings")
|
|
}
|
|
b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb})
|
|
if err != nil {
|
|
t.Errorf("marshaler.Marshal() failed %v", err)
|
|
}
|
|
want = append(want, b...)
|
|
want = append(want, "\n"...)
|
|
}
|
|
|
|
if string(body) != string(want) {
|
|
t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
|
|
}
|
|
})
|
|
}
|
|
}
|