package gig

import (

"crypto/tls"

"crypto/x509"

"crypto/x509/pkix"

"testing"

"github.com/matryer/is"

)

// TestCertAuth exercises the CertAuth and CertAuthWithConfig middleware.
//
// Each case wraps a trivial handler and sends three requests through it:
//   - one with no TLS connection state (no client certificate), expected to
//     yield expectedErrNoCert;
//   - one whose peer certificate has CommonName "not-tester", expected to
//     yield expectedErrBadCert;
//   - one whose peer certificate has CommonName "tester", which must succeed
//     and leave the context value "subject" equal to "tester".
func TestCertAuth(t *testing.T) {
	is := is.New(t)

	g := New()

	testCases := []struct {
		mw                 MiddlewareFunc
		expectedErrNoCert  error
		expectedErrBadCert error
		name               string
	}{
		{
			mw:                 CertAuth(ValidateHasCertificate),
			expectedErrNoCert:  ErrClientCertificateRequired,
			expectedErrBadCert: nil,
			name:               `ValidateHasCertificate`,
		},
		{
			mw: CertAuth(func(cert *x509.Certificate, c Context) *GeminiError {
				if cert == nil {
					return ErrClientCertificateRequired
				}

				if cert.Subject.CommonName != "tester" {
					return ErrCertificateNotValid
				}

				c.Set("subject", cert.Subject.CommonName)

				return nil
			}),
			expectedErrNoCert:  ErrClientCertificateRequired,
			expectedErrBadCert: ErrCertificateNotValid,
			name:               `CustomValidator`,
		},
		{
			// Leaving both fields nil exercises CertAuthWithConfig's defaults.
			mw: CertAuthWithConfig(CertAuthConfig{
				Skipper:   nil,
				Validator: nil,
			}),
			expectedErrNoCert:  ErrClientCertificateRequired,
			expectedErrBadCert: nil,
			name:               `NilConfig`,
		},
		{
			// The skipper sets "subject" itself and always skips validation,
			// so every request is expected to pass.
			mw: CertAuthWithConfig(CertAuthConfig{
				Skipper: func(c Context) bool {
					c.Set("subject", "tester")

					return true
				},
			}),
			expectedErrNoCert:  nil,
			expectedErrBadCert: nil,
			name:               `CustomSkipper`,
		},
	}

	for _, test := range testCases {
		test := test

		t.Run(test.name, func(t *testing.T) {
			// Rebind assertions to the subtest's t. Using the parent's `is`
			// here would call the parent T's FailNow from the subtest
			// goroutine, which the testing package does not allow, and the
			// failure would not be attributed to this subtest.
			is := is.New(t)

			h := test.mw(func(c Context) error {
				return c.Gemini("test")
			})

			// No certificate
			c, _ := g.NewFakeContext("/", nil)
			is.Equal(h(c), test.expectedErrNoCert)

			// Invalid certificate
			c, _ = g.NewFakeContext("/", &tls.ConnectionState{
				PeerCertificates: []*x509.Certificate{
					{Subject: pkix.Name{CommonName: "not-tester"}},
				},
			})
			is.Equal(h(c), test.expectedErrBadCert)

			// Valid certificate
			c, _ = g.NewFakeContext("/", &tls.ConnectionState{
				PeerCertificates: []*x509.Certificate{
					{Subject: pkix.Name{CommonName: "tester"}},
				},
			})
			is.NoErr(h(c))
			is.Equal("tester", c.Get("subject"))
		})
	}
}


Source