From 677893bc47b5ac22b9f60b280093a3c75700dc68 Mon Sep 17 00:00:00 2001 From: Dimitrij Drus Date: Wed, 14 Aug 2024 01:31:27 +0200 Subject: [PATCH] initial, not yet working implementation preserving the work done --- middleware.go | 323 +++++++++++++++++++++++++++++++++++++++++++++ middleware_test.go | 98 ++++++++++++++ transport.go | 24 ++++ transport_test.go | 74 +++++++++++ 4 files changed, 519 insertions(+) create mode 100644 middleware.go create mode 100644 middleware_test.go create mode 100644 transport.go create mode 100644 transport_test.go diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..43e6a3f --- /dev/null +++ b/middleware.go @@ -0,0 +1,323 @@ +package httpsig + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/dunglas/httpsfv" + "github.com/felixge/httpsnoop" +) + +type HandlerOption func(*handler) + +type Logger interface { + Logf(ctx context.Context, msg string, args ...any) +} + +type noopLogger struct{} + +func (noopLogger) Logf(_ context.Context, _ string, _ ...any) {} + +func WithLogger(logger Logger) HandlerOption { + return func(h *handler) { + if logger != nil { + h.l = logger + } + } +} + +func WithErrorCode(code int) HandlerOption { + return func(h *handler) { + h.ec = code + } +} + +func WithSignedResponses(signer *signer) HandlerOption { + return func(h *handler) { + if signer != nil { + h.s = append(h.s, signer) + } + } +} + +func WithSignedResponseNegotiation(kr KeyResolver) HandlerOption { + return func(h *handler) { + if kr != nil { + h.kr = kr + } + } +} + +type handler struct { + v Verifier + s compositeSigner + kr KeyResolver + l Logger + ec int +} + +func (h *handler) wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if err := h.v.Verify(MessageFromRequest(req)); err != nil { + var errNoApplicable *NoApplicableSignatureError + + if errors.As(err, &errNoApplicable) { + h.l.Logf(req.Context(), "No applicable http signature present.") + + errNoApplicable.Negotiate(rw.Header()) + } else { + h.l.Logf(req.Context(), "Failed verifying http signature: %v.", err) + } + + rw.WriteHeader(h.ec) + + return + } + + signer, done := h.signerFor(req) + if done { + rw.WriteHeader(h.ec) + + return + } + + next.ServeHTTP(newResponseWriterWrapper(h.l, signer, rw, req), req) + }) +} + +func (h *handler) signerFor(req *http.Request) (Signer, bool) { + if h.kr == nil { + return h.s, false + } + + sigReqs, err := getSignatureRequirements(req.Header.Values(headerAcceptSignature)) + if err != nil { + h.l.Logf(req.Context(), "Failed negotiating http signature for response: %v.", err) + + return nil, true + } + + signer := make(compositeSigner, 0, len(sigReqs)) + + for label, sigReq := range sigReqs { + key, err := h.kr.ResolveKey(req.Context(), sigReq.keyID) + if err != nil { + h.l.Logf(req.Context(), "Failed resolving key for http signature response: %v.", err) + + return nil, true + } + + if key.Algorithm != sigReq.alg { + h.l.Logf(req.Context(), "Requested key %s does not support requested algorithm %s.", + sigReq.keyID, sigReq.alg) + + return nil, true + } + + opts := []SignerOption{ + WithTag(sigReq.tag), + WithLabel(label), + //withComponents(sigReq.identifiers), + WithNonce(NonceGetterFunc(func(_ context.Context) (string, error) { + return sigReq.nonce, nil + })), + } + + if !sigReq.expires { + opts = append(opts, WithTTL(0)) + } + + sgnr, err := NewSigner(key, opts...) + if err != nil { + h.l.Logf(req.Context(), "Failed resolving key for http signature response: %v.", err) + + return nil, true + } + + signer = append(signer, sgnr) + } + + return signer, false +} + +func getSignatureRequirements(values []string) (map[string]*signatureRequirements, error) { + inputDict, err := httpsfv.UnmarshalDictionary(values) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrMalformedData, err) + } + + sigRefs := make(map[string]*signatureRequirements, len(inputDict.Names())) + + for _, label := range inputDict.Names() { + var sr signatureRequirements + + m, _ := inputDict.Get(label) + + sigReqs, ok := m.(httpsfv.InnerList) + if !ok { + return nil, fmt.Errorf("%w: unexpected signature requirements format", ErrMalformedData) + } + + if err = sr.fromInnerList(sigReqs); err != nil { + return nil, fmt.Errorf("%w: %w", ErrMalformedData, err) + } + + sigRefs[label] = &sr + } + + return sigRefs, nil +} + +func NewVerifierMiddleware(verifier Verifier, opts ...HandlerOption) func(http.Handler) http.Handler { + hdl := &handler{ + v: verifier, + l: noopLogger{}, + ec: http.StatusBadRequest, + } + + for _, opt := range opts { + opt(hdl) + } + + if hdl.s != nil && hdl.kr != nil { + panic("WithSignedResponses and WithSignedResponseNegotiation are mutually exclusive") + } + + return func(next http.Handler) http.Handler { + return hdl.wrap(next) + } +} + +type responseWriterAdapter struct { + s Signer + l Logger + rw http.ResponseWriter + req *http.Request + + msgSignedOrInProgress bool +} + +func newResponseWriterWrapper( + logger Logger, + signer Signer, + rw http.ResponseWriter, + req *http.Request, +) http.ResponseWriter { + if signer == nil { + return rw + } + + rwa := &responseWriterAdapter{ + l: logger, + s: signer, + req: req, + rw: rw, + } + + return httpsnoop.Wrap( + rw, + httpsnoop.Hooks{ + Flush: rwa.flush, + Write: rwa.write, + WriteHeader: rwa.writeHeader, + }, + ) +} + +func (a *responseWriterAdapter) sign(rw http.ResponseWriter, req *http.Request, data []byte, code int) error { + hdr, err := a.s.Sign(MessageForResponse(req, rw.Header(), data, code)) + if err != nil { + return err + } + + if len(hdr) == 0 { + return nil + } + + rw.Header().Set("Signature-Input", hdr.Get("Signature-Input")) + rw.Header().Set("Signature", hdr.Get("Signature")) + rw.Header().Add("Vary", "Signature-Input") + rw.Header().Add("Vary", "Signature") + + return nil +} + +func (a *responseWriterAdapter) flush(flush httpsnoop.FlushFunc) httpsnoop.FlushFunc { + return func() { + if a.msgSignedOrInProgress { + flush() + + return + } + + a.msgSignedOrInProgress = true + + if err := a.sign(a.rw, a.req, nil, http.StatusOK); err != nil { + a.l.Logf(a.req.Context(), "Failed signing http response: %v", err) + a.rw.WriteHeader(http.StatusInternalServerError) + + return + } + + flush() + } +} + +func (a *responseWriterAdapter) write(write httpsnoop.WriteFunc) httpsnoop.WriteFunc { + return func(data []byte) (int, error) { + if a.msgSignedOrInProgress { + return write(data) + } + + a.msgSignedOrInProgress = true + + if err := a.sign(a.rw, a.req, data, http.StatusOK); err != nil { + a.l.Logf(a.req.Context(), "Failed signing http response: %v", err) + + return 0, err + } + + return write(data) + } +} + +func (a *responseWriterAdapter) writeHeader(writeHeader httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc { + return func(code int) { + if a.msgSignedOrInProgress { + writeHeader(code) + + return + } + + a.msgSignedOrInProgress = true + + if err := a.sign(a.rw, a.req, nil, code); err != nil { + a.l.Logf(a.req.Context(), "Failed signing http response: %v", err) + writeHeader(http.StatusInternalServerError) + + return + } + + writeHeader(code) + } +} + +type compositeSigner []Signer + +func (c compositeSigner) Sign(msg *Message) (http.Header, error) { + var ( + hdr http.Header + err error + ) + + for _, signer := range c { + hdr, err = signer.Sign(msg) + if err != nil { + return nil, err + } + } + + return hdr, nil +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..7aa6537 --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,98 @@ +package httpsig + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSignatureMiddleware(t *testing.T) { + t.Parallel() + + pkp256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + sig, err := NewSigner( + Key{Key: pkp256, KeyID: "test", Algorithm: EcdsaP256Sha256}, + WithComponents("@authority", "@method"), + WithTTL(5*time.Second), + WithTag("test"), + ) + require.NoError(t, err) + + for _, tc := range []struct { + uc string + opts []HandlerOption + ver Verifier + assert func(t *testing.T, resp *http.Response, innerCalled bool) + }{ + { + uc: "without error", + ver: func() Verifier { + ver, err := NewVerifier(Key{Key: &pkp256.PublicKey, KeyID: "test", Algorithm: EcdsaP256Sha256}, + WithRequiredComponents("@authority", "@method"), + WithValidityTolerance(1*time.Second), + WithRequiredTag("test"), + ) + require.NoError(t, err) + + return ver + }(), + assert: func(t *testing.T, resp *http.Response, innerCalled bool) { + t.Helper() + + require.True(t, innerCalled) + assert.Equal(t, http.StatusOK, resp.StatusCode) + }, + }, + { + uc: "with error, default error handler", + ver: func() Verifier { + ver, err := NewVerifier(Key{Key: &pkp256.PublicKey, KeyID: "test", Algorithm: EcdsaP256Sha256}, + WithRequiredComponents("@authority", "@path"), + WithRequiredTag("test"), + ) + require.NoError(t, err) + + return ver + }(), + assert: func(t *testing.T, resp *http.Response, innerCalled bool) { + t.Helper() + + require.False(t, innerCalled) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }, + }, + } { + t.Run(tc.uc, func(t *testing.T) { + var handlerCalled bool + + middleware := NewVerifierMiddleware(tc.ver, tc.opts...) + + srv := httptest.NewServer(middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + handlerCalled = true + }))) + + defer srv.Close() + + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, srv.URL, nil) + require.NoError(t, err) + + client := http.Client{Transport: NewTransport(http.DefaultTransport, sig)} + resp, err := client.Do(req) + require.NoError(t, err) + + defer resp.Body.Close() + + tc.assert(t, resp, handlerCalled) + }) + } +} diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..cd303c1 --- /dev/null +++ b/transport.go @@ -0,0 +1,24 @@ +package httpsig + +import ( + "net/http" +) + +// NewTransport returns a new client transport that wraps the provided transport with +// http message signing and verifying. +func NewTransport(inner http.RoundTripper, signer Signer) http.RoundTripper { + return rt(func(req *http.Request) (*http.Response, error) { + hdr, err := signer.Sign(MessageFromRequest(req)) + if err != nil { + return nil, err + } + + req.Header = hdr + + return inner.RoundTrip(req) + }) +} + +type rt func(*http.Request) (*http.Response, error) + +func (r rt) RoundTrip(req *http.Request) (*http.Response, error) { return r(req) } diff --git a/transport_test.go b/transport_test.go new file mode 100644 index 0000000..cfb7ad2 --- /dev/null +++ b/transport_test.go @@ -0,0 +1,74 @@ +package httpsig + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/dunglas/httpsfv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTransport(t *testing.T) { + t.Parallel() + + var receivedHeader http.Header + + pkp256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + receivedHeader = req.Header.Clone() + + rw.WriteHeader(http.StatusOK) + })) + + defer srv.Close() + + sig, err := NewSigner( + Key{Key: pkp256, KeyID: "test", Algorithm: EcdsaP256Sha256}, + WithComponents("@authority", "@method"), + WithTTL(5*time.Second), + WithTag("test"), + ) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, srv.URL, nil) + require.NoError(t, err) + + client := http.Client{Transport: NewTransport(http.DefaultTransport, sig)} + resp, err := client.Do(req) + require.NoError(t, err) + + defer resp.Body.Close() + + inputDict, err := httpsfv.UnmarshalDictionary(receivedHeader.Values(headerSignatureInput)) + require.NoError(t, err) + + require.Contains(t, inputDict.Names(), "sig") + member, present := inputDict.Get("sig") + require.True(t, present) + + require.IsType(t, httpsfv.InnerList{}, member) + list := member.(httpsfv.InnerList) + require.Len(t, list.Items, 2) + assert.Equal(t, "@authority", list.Items[0].Value) + assert.Equal(t, "@method", list.Items[1].Value) + assert.ElementsMatch(t, list.Params.Names(), []string{"created", "expires", "keyid", "alg", "nonce", "tag"}) + + sigDict, err := httpsfv.UnmarshalDictionary(receivedHeader.Values(headerSignature)) + require.NoError(t, err) + + require.Contains(t, sigDict.Names(), "sig") + member, present = sigDict.Get("sig") + require.True(t, present) + require.IsType(t, httpsfv.Item{}, member) + item := member.(httpsfv.Item) + require.IsType(t, []byte{}, item.Value) +}