diff --git a/http_handlers.go b/http_handlers.go index 9221d21..aba0b41 100644 --- a/http_handlers.go +++ b/http_handlers.go @@ -20,6 +20,7 @@ func RegisterEchoHandlers(svcHandler ServiceHandler, e *echo.Echo) { if err != nil { return c.String(http.StatusInternalServerError, err.Error()) } + c.Response().Header().Set("Content-Transfer-Encoding", "base64") return c.Blob(200, "application/pkcs7-mime", certs) }) e.POST("/.well-known/est/simpleenroll", func(c echo.Context) error { @@ -38,7 +39,13 @@ func RegisterEchoHandlers(svcHandler ServiceHandler, e *echo.Echo) { } return c.String(http.StatusInternalServerError, err.Error()) } - return c.Blob(http.StatusCreated, "application/pkcs7-mime", bytes) + c.Response().Header().Set("Content-Transfer-Encoding", "base64") + if c.Request().UserAgent() == "fioconfig-client/2" { + // Older versions of fioconfig are requiring status code 201 and were not ignoring an optional `smime-type` extension. + // Thus, we have to return whatever it expects for backward compatibility with devices already deployed. + return c.Blob(http.StatusCreated, "application/pkcs7-mime", bytes) + } + return c.Blob(http.StatusOK, "application/pkcs7-mime; smime-type=certs-only", bytes) }) e.POST("/.well-known/est/simplereenroll", func(c echo.Context) error { svc, err := svcHandler.GetService(c.Request().Context(), c.Request().TLS.ServerName) @@ -57,7 +64,13 @@ func RegisterEchoHandlers(svcHandler ServiceHandler, e *echo.Echo) { } return c.String(http.StatusInternalServerError, err.Error()) } - return c.Blob(http.StatusCreated, "application/pkcs7-mime", bytes) + c.Response().Header().Set("Content-Transfer-Encoding", "base64") + if c.Request().UserAgent() == "fioconfig-client/2" { + // Older versions of fioconfig are requiring status code 201 and were not ignoring an optional `smime-type` extension. + // Thus, we have to return whatever it expects for backward compatibility with devices already deployed. + return c.Blob(http.StatusCreated, "application/pkcs7-mime", bytes) + } + return c.Blob(http.StatusOK, "application/pkcs7-mime; smime-type=certs-only", bytes) }) } diff --git a/http_test.go b/http_test.go index 525a4fe..fb3ad0d 100644 --- a/http_test.go +++ b/http_test.go @@ -18,10 +18,23 @@ import ( "go.mozilla.org/pkcs7" ) +const mimeTypePKCS7 = "application/pkcs7-mime" +const mimeTypePKCS7CertsOnly = "application/pkcs7-mime; smime-type=certs-only" + type testClient struct { svc Service srv *httptest.Server ctx context.Context + old bool +} + +type responseChecker func(t *testing.T, res *http.Response) + +func checkHeaderValue(header string, expectedvalue string) responseChecker { + var checker responseChecker = func(t *testing.T, res *http.Response) { + require.Equal(t, expectedvalue, res.Header.Get(header), "Unexpected content-type") + } + return checker } func (tc testClient) GET(t *testing.T, resource string) []byte { @@ -34,7 +47,7 @@ func (tc testClient) GET(t *testing.T, resource string) []byte { return buf } -func (tc testClient) POST(t *testing.T, resource string, data []byte, cert *tls.Certificate) (int, []byte) { +func (tc testClient) POST(t *testing.T, resource string, data []byte, cert *tls.Certificate, additionalChecks ...responseChecker) (int, []byte) { url := tc.srv.URL + resource client := tc.srv.Client() if cert != nil { @@ -42,10 +55,20 @@ func (tc testClient) POST(t *testing.T, resource string, data []byte, cert *tls. transport.TLSClientConfig.Certificates = []tls.Certificate{*cert} } - res, err := client.Post(url, "application/pkcs10", bytes.NewBuffer(data)) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(data)) + require.Nil(t, err) + req.Header.Set("Content-Type", "application/pkcs10") + if tc.old { + req.Header.Set("User-Agent", "fioconfig-client/2") + } + res, err := client.Do(req) + require.Nil(t, err) buf, err := io.ReadAll(res.Body) require.Nil(t, err) + for _, check := range additionalChecks { + check(t, res) + } return res.StatusCode, buf } @@ -122,8 +145,14 @@ func TestSimpleEnroll(t *testing.T) { require.Equal(t, "The CSR could not be decoded: asn1: syntax error: sequence truncated", string(data)) _, csr := createB64CsrDer(t, cn) - rc, data = tc.POST(t, "/.well-known/est/simpleenroll", csr, kp) + rc, data = tc.POST(t, "/.well-known/est/simpleenroll", csr, kp, checkHeaderValue("content-type", mimeTypePKCS7CertsOnly)) + require.Equal(t, 200, rc, string(data)) + + // backward compatablity test + tc.old = true + rc, data = tc.POST(t, "/.well-known/est/simpleenroll", csr, kp, checkHeaderValue("content-type", mimeTypePKCS7)) require.Equal(t, 201, rc, string(data)) + tc.old = false buf, err := base64.StdEncoding.DecodeString(string(data)) require.Nil(t, err) @@ -158,8 +187,14 @@ func TestSimpleReEnroll(t *testing.T) { require.Equal(t, "The CSR could not be decoded: asn1: syntax error: sequence truncated", string(data)) newkey, csr := createB64CsrDer(t, cn) - rc, data = tc.POST(t, "/.well-known/est/simplereenroll", csr, kp) + rc, data = tc.POST(t, "/.well-known/est/simplereenroll", csr, kp, checkHeaderValue("content-type", mimeTypePKCS7CertsOnly)) + require.Equal(t, 200, rc, string(data)) + + // backward compatablity test + tc.old = true + rc, data = tc.POST(t, "/.well-known/est/simpleenroll", csr, kp, checkHeaderValue("content-type", mimeTypePKCS7)) require.Equal(t, 201, rc, string(data)) + tc.old = false buf, err := base64.StdEncoding.DecodeString(string(data)) require.Nil(t, err) @@ -174,8 +209,8 @@ func TestSimpleReEnroll(t *testing.T) { PrivateKey: newkey, } - rc, data = tc.POST(t, "/.well-known/est/simplereenroll", csr, kp) - require.Equal(t, 201, rc, string(data)) + rc, data = tc.POST(t, "/.well-known/est/simplereenroll", csr, kp, checkHeaderValue("content-type", mimeTypePKCS7CertsOnly)) + require.Equal(t, 200, rc, string(data)) }) }