diff --git a/cmd/protoc-gen-goclay/genhandler/tpl_httpclient.go b/cmd/protoc-gen-goclay/genhandler/tpl_httpclient.go index 7d18e4b..2909f0b 100644 --- a/cmd/protoc-gen-goclay/genhandler/tpl_httpclient.go +++ b/cmd/protoc-gen-goclay/genhandler/tpl_httpclient.go @@ -45,6 +45,7 @@ func (c *{{ $svc.GetName }}_httpClient) {{ $m.GetName }}(ctx {{ pkg "context" }} if err != nil { return nil, {{ pkg "errors" }}Wrap(err, "can't initiate HTTP request") } + req = req.WithContext(ctx) req.Header.Add("Accept", m.ContentType()) diff --git a/integration/Makefile b/integration/Makefile index 0318804..7b000ac 100644 --- a/integration/Makefile +++ b/integration/Makefile @@ -8,6 +8,8 @@ VGO_BIN:=$(GOBIN)/vgo GEN_CLAY_BIN:=$(CURDIR)/bin/protoc-gen-goclay export GEN_CLAY_BIN +all: clean build test + # install vgo $(VGO_BIN): ifeq (${VGO_VERSION},master) @@ -55,4 +57,4 @@ test: $(VGO_BIN) build fi clean: - @find */ -type f -name Makefile -execdir sh -c "make clean; echo ;" \; \ No newline at end of file + @find */ -type f -name Makefile -execdir sh -c "make clean; echo ;" \; diff --git a/integration/http_headers_passthru/pb/strings/strings_test.go b/integration/http_headers_passthru/pb/strings/strings_test.go index e933c94..ae14e96 100644 --- a/integration/http_headers_passthru/pb/strings/strings_test.go +++ b/integration/http_headers_passthru/pb/strings/strings_test.go @@ -105,6 +105,64 @@ func TestHTTPHeadersPass_genClient(t *testing.T) { so.True(calledFunc) } +// TestHTTPHeadersPass_genClient_outgoingContext tests that generated HTTP client +// passes headers from grpc.ToOutgoingContext to the request by default. +func TestHTTPHeadersPass_genClient_outgoingContext(t *testing.T) { + so := assert.New(t) + impl, ts := getTestSvc() + defer ts.Close() + + tc := []a{ + a{ + Name: "User-Agent", + Values: []string{"Go-http-client/1.1"}, + }, + a{ + Name: "Accept", + Values: []string{"application/json"}, + }, + } + pt := []a{ + a{ + Name: "X-Test-Passthrough", + Values: []string{"v1", "Value2", "3"}, + }, + } + + calledFunc := false + + impl.f = func(ctx context.Context, req *String) (*String, error) { + calledFunc = true + md, ok := metadata.FromIncomingContext(ctx) + so.True(ok) + + for _, c := range tc { + got := md.Get(strings.ToLower(c.Name)) + so.EqualValues(c.Values, got) + } + + for _, c := range pt { + got := md.Get(strings.ToLower(c.Name)) + so.EqualValues(c.Values, got) + } + + return &String{}, nil + } + + ctx := context.Background() + + for _, c := range pt { + for i := range c.Values { + ctx = metadata.AppendToOutgoingContext(ctx, c.Name, c.Values[i]) + } + } + + cli := NewStringsHTTPClient(http.DefaultClient, ts.URL) + _, err := cli.ToLower(ctx, &String{}) + so.Nil(err) + so.True(calledFunc) +} + func getTestSvc() (*StringsImplementation, *httptest.Server) { mux := http.NewServeMux() impl := NewStrings() diff --git a/transport/httpclient/requestmw.go b/transport/httpclient/requestmw.go index b19a35d..228a07e 100644 --- a/transport/httpclient/requestmw.go +++ b/transport/httpclient/requestmw.go @@ -81,7 +81,7 @@ type RequestMutator func(*http.Request) (*http.Request, error) type ResponseMutator func(*http.Response) (*http.Response, error) // DefaultRequestMutators are used for every outgoing request. -var DefaultRequestMutators = []RequestMutator{} +var DefaultRequestMutators = []RequestMutator{clientReqHeadersFromMD()} // DefaultResponseMutators are used for every received response. var DefaultResponseMutators = []ResponseMutator{} @@ -95,3 +95,25 @@ func clientRspHeaderCopier(md *metadata.MD) ResponseMutator { return rsp, nil } } + +// clientReqHeadersFromMD pushes metadata from OutgoingContext to the +// request headers. +func clientReqHeadersFromMD() RequestMutator { + return func(req *http.Request) (*http.Request, error) { + fmt.Println("called") + ctxmd, ok := metadata.FromOutgoingContext(req.Context()) + if !ok { + return req, nil + } + + for k := range ctxmd { + vv := ctxmd.Get(k) + for i := range vv { + req.Header.Add(k, vv[i]) + } + } + + return req, nil + } + +}