Files
portainer/api/http/csrf/csrf.go

160 lines
4.6 KiB
Go

package csrf
import (
"crypto/rand"
"errors"
"fmt"
"net/http"
"os"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/pkg/featureflags"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
gcsrf "github.com/gorilla/csrf"
"github.com/rs/zerolog/log"
"github.com/urfave/negroni"
)
// csrfSkipHeader is the name of the marker response header set by
// SkipCSRFToken. withLegacySendCSRFToken checks for it and, when present,
// deletes it from the response instead of adding the X-CSRF-Token header.
const csrfSkipHeader = "X-CSRF-Token-Skip"
// SkipCSRFToken marks the response written through w so that no X-CSRF-Token
// header is attached to it. It does so by setting the csrfSkipHeader marker
// header to "1"; withLegacySendCSRFToken looks for that marker and removes it
// again.
//
// Deprecated: only meaningful when the "legacy-csrf" feature flag is enabled.
func SkipCSRFToken(w http.ResponseWriter) {
	headers := w.Header()
	headers.Set(csrfSkipHeader, "1")
}
// WithProtect wraps handler with CSRF protection and returns the wrapped
// handler.
//
// By default the check is delegated to http.CrossOriginProtection, with every
// entry of trustedOrigins registered as a trusted origin. A rejected request is
// logged and answered with 403 Forbidden. When the "legacy-csrf" feature flag
// is enabled, the handler built by withLegacyProtect is returned instead.
//
// Requests for which security.ShouldSkipCSRFCheck reports true are passed to
// handler without the cross-origin check; an error from that call is answered
// with 403 Forbidden.
//
// A non-nil error is returned when a trusted origin cannot be registered, or
// when withLegacyProtect fails in legacy mode.
func WithProtect(handler http.Handler, trustedOrigins []string) (http.Handler, error) {
	// DOCKER_EXTENSION=1 is set in build/docker-extension/docker-compose.yml
	isDockerDesktopExtension := os.Getenv("DOCKER_EXTENSION") == "1"

	if featureflags.IsEnabled("legacy-csrf") {
		return withLegacyProtect(handler, trustedOrigins, isDockerDesktopExtension)
	}

	cop := http.NewCrossOriginProtection()
	for _, origin := range trustedOrigins {
		err := cop.AddTrustedOrigin(origin)
		if err != nil {
			return nil, fmt.Errorf("failed to add trusted origin %q: %w", origin, err)
		}
	}

	onDeny := func(w http.ResponseWriter, r *http.Request) {
		// The deny handler only receives the request, so the check is run
		// again here to obtain the rejection reason for the log entry.
		reason := cop.Check(r)

		log.Error().Err(reason).
			Str("request_url", r.URL.String()).
			Str("host", r.Host).
			Str("origin", r.Header.Get("Origin")).
			Str("sec_fetch_site", r.Header.Get("Sec-Fetch-Site")).
			Strs("trusted_origins", trustedOrigins).
			Msg("CSRF check failed")

		http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
	}
	cop.SetDenyHandler(http.HandlerFunc(onDeny))

	guarded := cop.Handler(handler)

	dispatch := func(w http.ResponseWriter, r *http.Request) {
		skip, err := security.ShouldSkipCSRFCheck(r, isDockerDesktopExtension)

		switch {
		case err != nil:
			httperror.WriteError(w, http.StatusForbidden, err.Error(), err)
		case skip:
			// Bypass the cross-origin check and call the unwrapped handler.
			handler.ServeHTTP(w, r)
		default:
			guarded.ServeHTTP(w, r)
		}
	}

	return http.HandlerFunc(dispatch), nil
}
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
func withLegacyProtect(handler http.Handler, trustedOrigins []string, isDockerDesktopExtension bool) (http.Handler, error) {
handler = withLegacySendCSRFToken(handler)
token := make([]byte, 32)
if _, err := rand.Read(token); err != nil {
return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
}
handler = gcsrf.Protect(
token,
gcsrf.Path("/"),
gcsrf.Secure(false),
gcsrf.TrustedOrigins(trustedOrigins),
gcsrf.ErrorHandler(withLegacyErrorHandler(trustedOrigins)),
)(handler)
return withLegacySkipCSRF(handler, isDockerDesktopExtension), nil
}
// withLegacySendCSRFToken wraps handler so that responses with a 2xx status
// carry the request's CSRF token in the X-CSRF-Token header.
//
// The decision is made in a negroni "before" hook on the wrapped response
// writer. If the csrfSkipHeader marker is present on the response, the marker
// is removed and no token header is added, whatever the status.
//
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
func withLegacySendCSRFToken(handler http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		rw := negroni.NewResponseWriter(w)

		rw.Before(func(pending negroni.ResponseWriter) {
			headers := pending.Header()

			// The marker is internal: strip it and leave the token out.
			if headers.Get(csrfSkipHeader) != "" {
				headers.Del(csrfSkipHeader)
				return
			}

			// Only successful (2xx) responses get the token.
			status := pending.Status()
			if status < 200 || status >= 300 {
				return
			}

			headers.Set("X-CSRF-Token", gcsrf.Token(r))
		})

		handler.ServeHTTP(rw, r)
	})
}
// withLegacySkipCSRF wraps handler so that requests for which
// security.ShouldSkipCSRFCheck reports true are passed on through
// gcsrf.UnsafeSkipCheck; all other requests are forwarded unchanged. An error
// from ShouldSkipCSRFCheck is answered with 403 Forbidden and handler is not
// called.
//
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
func withLegacySkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		skip, err := security.ShouldSkipCSRFCheck(r, isDockerDesktopExtension)
		if err != nil {
			httperror.WriteError(w, http.StatusForbidden, err.Error(), err)
			return
		}

		req := r
		if skip {
			req = gcsrf.UnsafeSkipCheck(r)
		}

		handler.ServeHTTP(w, req)
	}

	return http.HandlerFunc(serve)
}
// withLegacyErrorHandler returns the handler passed to gcsrf.ErrorHandler.
//
// It replies with 403 Forbidden, using the status text followed by " - " and
// the failure reason reported by gcsrf.FailureReason as the body. Failures
// caused by a bad origin, a bad referer or a missing referer are additionally
// logged together with the relevant request headers and trustedOrigins.
//
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
func withLegacyErrorHandler(trustedOrigins []string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reason := gcsrf.FailureReason(r)

		originOrRefererFailure := errors.Is(reason, gcsrf.ErrBadOrigin) ||
			errors.Is(reason, gcsrf.ErrBadReferer) ||
			errors.Is(reason, gcsrf.ErrNoReferer)

		if originOrRefererFailure {
			log.Error().Err(reason).
				Str("request_url", r.URL.String()).
				Str("host", r.Host).
				Str("x_forwarded_proto", r.Header.Get("X-Forwarded-Proto")).
				Str("forwarded", r.Header.Get("Forwarded")).
				Str("origin", r.Header.Get("Origin")).
				Str("referer", r.Header.Get("Referer")).
				Strs("trusted_origins", trustedOrigins).
				Msg("Failed to validate Origin or Referer")
		}

		body := http.StatusText(http.StatusForbidden) + " - " + reason.Error()
		http.Error(w, body, http.StatusForbidden)
	})
}