Compare commits

..

14 Commits

Author SHA1 Message Date
Malcolm Lockyer cc7f7008cb chore: bump version to 2.27.9 (#852) 2025-07-02 15:21:45 +12:00
andres-portainer 1e1998e269 feat(csrf): add trusted origins cli flags [BE-11972] (#839)
Co-authored-by: oscarzhou <oscar.zhou@portainer.io>
Co-authored-by: andres-portainer <andres-portainer@users.noreply.github.com>
2025-07-01 21:38:02 -03:00
Steven Kang 973c99dcf4 bump version to 2.27.8 (#824) 2025-06-25 09:43:58 +12:00
Cara Ryan 7fd5b96130 fix(kubernetes): Namespace access permission changes role bindings not created [R8S-366] (#807)
Co-authored-by: andres-portainer <andres-portainer@users.noreply.github.com>
Co-authored-by: Malcolm Lockyer <segfault88@users.noreply.github.com>
2025-06-25 09:24:18 +12:00
Steven Kang ee6d33365e bump version to 2.27.7 (#804) 2025-06-17 09:43:15 +12:00
Steven Kang e115055a1b security: cve-2025-22874 & cve-2025-22871 bump go to 1.23.10 (#799) 2025-06-12 17:30:49 +12:00
Devon Steenberg 384cb53c64 fix(proxy): whitelist headers for proxy to forward [BE-11819] (#760) 2025-05-30 11:49:41 +12:00
Oscar Zhou 4240cbf029 fix(csrf): skip trustedorigin for http request and check x-forwarded-proto for reverse proxy [BE-11832] (#713) 2025-05-09 13:45:33 +12:00
Steven Kang eb28dd4f4e chore: bump version to 2.27.6 (#720) 2025-05-09 09:57:27 +12:00
Steven Kang 78127f8f3d chore: bump version to 2.27.5 (#704) 2025-05-02 10:08:56 +12:00
Oscar Zhou c474322889 fix(dependencies): downgrade gorilla/csrf to v1.7.2 [BE-11832] (#689) 2025-04-24 12:14:18 +12:00
Oscar Zhou 83527da1a8 fix: cve-2025-22871 [BE-11825] (#677) 2025-04-22 21:29:18 +12:00
Oscar Zhou 7c8bef84b1 feat(docker): backport --pull-limit-check-disabled cli flag [BE-11820] (#658) 2025-04-16 19:28:43 +12:00
Steven Kang 5b3dba130b chore: bump version to 2.27.4 (#645) 2025-04-15 10:24:20 +12:00
22 changed files with 1230 additions and 78 deletions
+2
View File
@@ -60,6 +60,8 @@ func CLIFlags() *portainer.CLIFlags {
LogLevel: kingpin.Flag("log-level", "Set the minimum logging level to show").Default("INFO").Enum("DEBUG", "INFO", "WARN", "ERROR"),
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"),
KubectlShellImage: kingpin.Flag("kubectl-shell-image", "Kubectl shell image").Envar(portainer.KubectlShellImageEnvVar).Default(portainer.DefaultKubectlShellImage).String(),
PullLimitCheckDisabled: kingpin.Flag("pull-limit-check-disabled", "Pull limit check").Envar(portainer.PullLimitCheckDisabledEnvVar).Default(defaultPullLimitCheckDisabled).Bool(),
TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(),
}
}
+17 -16
View File
@@ -4,20 +4,21 @@
package cli
const (
defaultBindAddress = ":9000"
defaultHTTPSBindAddress = ":9443"
defaultTunnelServerAddress = "0.0.0.0"
defaultTunnelServerPort = "8000"
defaultDataDirectory = "/data"
defaultAssetsDirectory = "./"
defaultTLS = "false"
defaultTLSSkipVerify = "false"
defaultTLSCACertPath = "/certs/ca.pem"
defaultTLSCertPath = "/certs/cert.pem"
defaultTLSKeyPath = "/certs/key.pem"
defaultHTTPDisabled = "false"
defaultHTTPEnabled = "false"
defaultSSL = "false"
defaultBaseURL = "/"
defaultSecretKeyName = "portainer"
defaultBindAddress = ":9000"
defaultHTTPSBindAddress = ":9443"
defaultTunnelServerAddress = "0.0.0.0"
defaultTunnelServerPort = "8000"
defaultDataDirectory = "/data"
defaultAssetsDirectory = "./"
defaultTLS = "false"
defaultTLSSkipVerify = "false"
defaultTLSCACertPath = "/certs/ca.pem"
defaultTLSCertPath = "/certs/cert.pem"
defaultTLSKeyPath = "/certs/key.pem"
defaultHTTPDisabled = "false"
defaultHTTPEnabled = "false"
defaultSSL = "false"
defaultBaseURL = "/"
defaultSecretKeyName = "portainer"
defaultPullLimitCheckDisabled = "false"
)
+18 -17
View File
@@ -1,21 +1,22 @@
package cli
const (
defaultBindAddress = ":9000"
defaultHTTPSBindAddress = ":9443"
defaultTunnelServerAddress = "0.0.0.0"
defaultTunnelServerPort = "8000"
defaultDataDirectory = "C:\\data"
defaultAssetsDirectory = "./"
defaultTLS = "false"
defaultTLSSkipVerify = "false"
defaultTLSCACertPath = "C:\\certs\\ca.pem"
defaultTLSCertPath = "C:\\certs\\cert.pem"
defaultTLSKeyPath = "C:\\certs\\key.pem"
defaultHTTPDisabled = "false"
defaultHTTPEnabled = "false"
defaultSSL = "false"
defaultSnapshotInterval = "5m"
defaultBaseURL = "/"
defaultSecretKeyName = "portainer"
defaultBindAddress = ":9000"
defaultHTTPSBindAddress = ":9443"
defaultTunnelServerAddress = "0.0.0.0"
defaultTunnelServerPort = "8000"
defaultDataDirectory = "C:\\data"
defaultAssetsDirectory = "./"
defaultTLS = "false"
defaultTLSSkipVerify = "false"
defaultTLSCACertPath = "C:\\certs\\ca.pem"
defaultTLSCertPath = "C:\\certs\\cert.pem"
defaultTLSKeyPath = "C:\\certs\\key.pem"
defaultHTTPDisabled = "false"
defaultHTTPEnabled = "false"
defaultSSL = "false"
defaultSnapshotInterval = "5m"
defaultBaseURL = "/"
defaultSecretKeyName = "portainer"
defaultPullLimitCheckDisabled = "false"
)
+15
View File
@@ -50,6 +50,7 @@ import (
"github.com/portainer/portainer/pkg/featureflags"
"github.com/portainer/portainer/pkg/libhelm"
"github.com/portainer/portainer/pkg/libstack/compose"
"github.com/portainer/portainer/pkg/validate"
"github.com/gofrs/uuid"
"github.com/rs/zerolog/log"
@@ -328,6 +329,18 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
featureflags.Parse(*flags.FeatureFlags, portainer.SupportedFeatureFlags)
}
trustedOrigins := []string{}
if *flags.TrustedOrigins != "" {
// validate if the trusted origins are valid urls
for _, origin := range strings.Split(*flags.TrustedOrigins, ",") {
if !validate.IsTrustedOrigin(origin) {
log.Fatal().Str("trusted_origin", origin).Msg("invalid url for trusted origin. Please check the trusted origins flag.")
}
trustedOrigins = append(trustedOrigins, origin)
}
}
fileService := initFileService(*flags.Data)
encryptionKey := loadEncryptionSecretKey(*flags.SecretKeyName)
if encryptionKey == nil {
@@ -575,6 +588,8 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
AdminCreationDone: adminCreationDone,
PendingActionsService: pendingActionsService,
PlatformService: platformService,
PullLimitCheckDisabled: *flags.PullLimitCheckDisabled,
TrustedOrigins: trustedOrigins,
}
}
@@ -610,7 +610,7 @@
"RequiredPasswordLength": 12
},
"KubeconfigExpiry": "0",
"KubectlShellImage": "portainer/kubectl-shell:2.27.4",
"KubectlShellImage": "portainer/kubectl-shell:2.27.9",
"LDAPSettings": {
"AnonymousMode": true,
"AutoCreateUsers": true,
@@ -943,7 +943,7 @@
}
],
"version": {
"VERSION": "{\"SchemaVersion\":\"2.27.4\",\"MigratorCount\":0,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
"VERSION": "{\"SchemaVersion\":\"2.27.9\",\"MigratorCount\":0,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
},
"webhooks": null
}
+35 -7
View File
@@ -2,6 +2,7 @@ package csrf
import (
"crypto/rand"
"errors"
"fmt"
"net/http"
"os"
@@ -9,7 +10,8 @@ import (
"github.com/portainer/portainer/api/http/security"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
gorillacsrf "github.com/gorilla/csrf"
gcsrf "github.com/gorilla/csrf"
"github.com/rs/zerolog/log"
"github.com/urfave/negroni"
)
@@ -19,7 +21,7 @@ func SkipCSRFToken(w http.ResponseWriter) {
w.Header().Set(csrfSkipHeader, "1")
}
func WithProtect(handler http.Handler) (http.Handler, error) {
func WithProtect(handler http.Handler, trustedOrigins []string) (http.Handler, error) {
// IsDockerDesktopExtension is used to check if we should skip csrf checks in the request bouncer (ShouldSkipCSRFCheck)
// DOCKER_EXTENSION is set to '1' in build/docker-extension/docker-compose.yml
isDockerDesktopExtension := false
@@ -34,10 +36,12 @@ func WithProtect(handler http.Handler) (http.Handler, error) {
return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
}
handler = gorillacsrf.Protect(
handler = gcsrf.Protect(
token,
gorillacsrf.Path("/"),
gorillacsrf.Secure(false),
gcsrf.Path("/"),
gcsrf.Secure(false),
gcsrf.TrustedOrigins(trustedOrigins),
gcsrf.ErrorHandler(withErrorHandler(trustedOrigins)),
)(handler)
return withSkipCSRF(handler, isDockerDesktopExtension), nil
@@ -55,7 +59,7 @@ func withSendCSRFToken(handler http.Handler) http.Handler {
}
if statusCode := sw.Status(); statusCode >= 200 && statusCode < 300 {
sw.Header().Set("X-CSRF-Token", gorillacsrf.Token(r))
sw.Header().Set("X-CSRF-Token", gcsrf.Token(r))
}
})
@@ -73,9 +77,33 @@ func withSkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Hand
}
if skip {
r = gorillacsrf.UnsafeSkipCheck(r)
r = gcsrf.UnsafeSkipCheck(r)
}
handler.ServeHTTP(w, r)
})
}
func withErrorHandler(trustedOrigins []string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := gcsrf.FailureReason(r)
if errors.Is(err, gcsrf.ErrBadOrigin) || errors.Is(err, gcsrf.ErrBadReferer) || errors.Is(err, gcsrf.ErrNoReferer) {
log.Error().Err(err).
Str("request_url", r.URL.String()).
Str("host", r.Host).
Str("x_forwarded_proto", r.Header.Get("X-Forwarded-Proto")).
Str("forwarded", r.Header.Get("Forwarded")).
Str("origin", r.Header.Get("Origin")).
Str("referer", r.Header.Get("Referer")).
Strs("trusted_origins", trustedOrigins).
Msg("Failed to validate Origin or Referer")
}
http.Error(
w,
http.StatusText(http.StatusForbidden)+" - "+err.Error(),
http.StatusForbidden,
)
})
}
@@ -80,6 +80,13 @@ func (handler *Handler) endpointDockerhubStatus(w http.ResponseWriter, r *http.R
}
}
if handler.PullLimitCheckDisabled {
return response.JSON(w, &dockerhubStatusResponse{
Limit: 10,
Remaining: 10,
})
}
httpClient := client.NewHTTPClient()
token, err := getDockerHubToken(httpClient, registry)
if err != nil {
@@ -75,7 +75,7 @@ func (handler *Handler) listRegistries(tx dataservices.DataStoreTx, r *http.Requ
return nil, httperror.InternalServerError("Unable to retrieve registries from the database", err)
}
registries, handleError := handler.filterRegistriesByAccess(r, registries, endpoint, user, securityContext.UserMemberships)
registries, handleError := handler.filterRegistriesByAccess(tx, r, registries, endpoint, user, securityContext.UserMemberships)
if handleError != nil {
return nil, handleError
}
@@ -87,15 +87,15 @@ func (handler *Handler) listRegistries(tx dataservices.DataStoreTx, r *http.Requ
return registries, err
}
func (handler *Handler) filterRegistriesByAccess(r *http.Request, registries []portainer.Registry, endpoint *portainer.Endpoint, user *portainer.User, memberships []portainer.TeamMembership) ([]portainer.Registry, *httperror.HandlerError) {
func (handler *Handler) filterRegistriesByAccess(tx dataservices.DataStoreTx, r *http.Request, registries []portainer.Registry, endpoint *portainer.Endpoint, user *portainer.User, memberships []portainer.TeamMembership) ([]portainer.Registry, *httperror.HandlerError) {
if !endpointutils.IsKubernetesEndpoint(endpoint) {
return security.FilterRegistries(registries, user, memberships, endpoint.ID), nil
}
return handler.filterKubernetesEndpointRegistries(r, registries, endpoint, user, memberships)
return handler.filterKubernetesEndpointRegistries(tx, r, registries, endpoint, user, memberships)
}
func (handler *Handler) filterKubernetesEndpointRegistries(r *http.Request, registries []portainer.Registry, endpoint *portainer.Endpoint, user *portainer.User, memberships []portainer.TeamMembership) ([]portainer.Registry, *httperror.HandlerError) {
func (handler *Handler) filterKubernetesEndpointRegistries(tx dataservices.DataStoreTx, r *http.Request, registries []portainer.Registry, endpoint *portainer.Endpoint, user *portainer.User, memberships []portainer.TeamMembership) ([]portainer.Registry, *httperror.HandlerError) {
namespaceParam, _ := request.RetrieveQueryParameter(r, "namespace", true)
isAdmin, err := security.IsAdmin(r)
if err != nil {
@@ -116,7 +116,7 @@ func (handler *Handler) filterKubernetesEndpointRegistries(r *http.Request, regi
return registries, nil
}
return handler.filterKubernetesRegistriesByUserRole(r, registries, endpoint, user)
return handler.filterKubernetesRegistriesByUserRole(tx, r, registries, endpoint, user)
}
func (handler *Handler) isNamespaceAuthorized(endpoint *portainer.Endpoint, namespace string, userId portainer.UserID, memberships []portainer.TeamMembership, isAdmin bool) (bool, error) {
@@ -169,7 +169,7 @@ func registryAccessPoliciesContainsNamespace(registryAccess portainer.RegistryAc
return false
}
func (handler *Handler) filterKubernetesRegistriesByUserRole(r *http.Request, registries []portainer.Registry, endpoint *portainer.Endpoint, user *portainer.User) ([]portainer.Registry, *httperror.HandlerError) {
func (handler *Handler) filterKubernetesRegistriesByUserRole(tx dataservices.DataStoreTx, r *http.Request, registries []portainer.Registry, endpoint *portainer.Endpoint, user *portainer.User) ([]portainer.Registry, *httperror.HandlerError) {
err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint)
if errors.Is(err, security.ErrAuthorizationRequired) {
return nil, httperror.Forbidden("User is not authorized", err)
@@ -178,7 +178,7 @@ func (handler *Handler) filterKubernetesRegistriesByUserRole(r *http.Request, re
return nil, httperror.InternalServerError("Unable to retrieve info from request context", err)
}
userNamespaces, err := handler.userNamespaces(endpoint, user)
userNamespaces, err := handler.userNamespaces(tx, endpoint, user)
if err != nil {
return nil, httperror.InternalServerError("unable to retrieve user namespaces", err)
}
@@ -186,7 +186,7 @@ func (handler *Handler) filterKubernetesRegistriesByUserRole(r *http.Request, re
return filterRegistriesByNamespaces(registries, endpoint.ID, userNamespaces), nil
}
func (handler *Handler) userNamespaces(endpoint *portainer.Endpoint, user *portainer.User) ([]string, error) {
func (handler *Handler) userNamespaces(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint, user *portainer.User) ([]string, error) {
kcl, err := handler.K8sClientFactory.GetPrivilegedKubeClient(endpoint)
if err != nil {
return nil, err
@@ -197,7 +197,7 @@ func (handler *Handler) userNamespaces(endpoint *portainer.Endpoint, user *porta
return nil, err
}
userMemberships, err := handler.DataStore.TeamMembership().TeamMembershipsByUserID(user.ID)
userMemberships, err := tx.TeamMembership().TeamMembershipsByUserID(user.ID)
if err != nil {
return nil, err
}
+14 -13
View File
@@ -26,19 +26,20 @@ func hideFields(endpoint *portainer.Endpoint) {
// Handler is the HTTP handler used to handle environment(endpoint) operations.
type Handler struct {
*mux.Router
requestBouncer security.BouncerService
DataStore dataservices.DataStore
FileService portainer.FileService
ProxyManager *proxy.Manager
ReverseTunnelService portainer.ReverseTunnelService
SnapshotService portainer.SnapshotService
K8sClientFactory *cli.ClientFactory
ComposeStackManager portainer.ComposeStackManager
AuthorizationService *authorization.Service
DockerClientFactory *dockerclient.ClientFactory
BindAddress string
BindAddressHTTPS string
PendingActionsService *pendingactions.PendingActionsService
requestBouncer security.BouncerService
DataStore dataservices.DataStore
FileService portainer.FileService
ProxyManager *proxy.Manager
ReverseTunnelService portainer.ReverseTunnelService
SnapshotService portainer.SnapshotService
K8sClientFactory *cli.ClientFactory
ComposeStackManager portainer.ComposeStackManager
AuthorizationService *authorization.Service
DockerClientFactory *dockerclient.ClientFactory
BindAddress string
BindAddressHTTPS string
PendingActionsService *pendingactions.PendingActionsService
PullLimitCheckDisabled bool
}
// NewHandler creates a handler to manage environment(endpoint) operations.
+1 -1
View File
@@ -81,7 +81,7 @@ type Handler struct {
}
// @title PortainerCE API
// @version 2.27.4
// @version 2.27.9
// @description.markdown api-description.md
// @termsOfService
@@ -0,0 +1,76 @@
package middlewares
import (
"net/http"
"slices"
"strings"
"github.com/gorilla/csrf"
)
var (
// Idempotent (safe) methods as defined by RFC7231 section 4.2.2.
safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"}
)
type plainTextHTTPRequestHandler struct {
next http.Handler
}
// parseForwardedHeaderProto parses the Forwarded header and extracts the protocol.
// The Forwarded header format supports:
// - Single proxy: Forwarded: by=<identifier>;for=<identifier>;host=<host>;proto=<http|https>
// - Multiple proxies: Forwarded: for=192.0.2.43, for=198.51.100.17
// We take the first (leftmost) entry as it represents the original client
func parseForwardedHeaderProto(forwarded string) string {
if forwarded == "" {
return ""
}
// Parse the first part (leftmost proxy, closest to original client)
firstPart, _, _ := strings.Cut(forwarded, ",")
firstPart = strings.TrimSpace(firstPart)
// Split by semicolon to get key-value pairs within this proxy entry
// Format: key=value;key=value;key=value
pairs := strings.Split(firstPart, ";")
for _, pair := range pairs {
// Split by equals sign to separate key and value
key, value, found := strings.Cut(pair, "=")
if !found {
continue
}
if strings.EqualFold(strings.TrimSpace(key), "proto") {
return strings.Trim(strings.TrimSpace(value), `"'`)
}
}
return ""
}
// isHTTPSRequest checks if the original request was made over HTTPS
// by examining both X-Forwarded-Proto and Forwarded headers
func isHTTPSRequest(r *http.Request) bool {
return strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") ||
strings.EqualFold(parseForwardedHeaderProto(r.Header.Get("Forwarded")), "https")
}
func (h *plainTextHTTPRequestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if slices.Contains(safeMethods, r.Method) {
h.next.ServeHTTP(w, r)
return
}
req := r
// If original request was HTTPS (via proxy), keep CSRF checks.
if !isHTTPSRequest(r) {
req = csrf.PlaintextHTTPRequest(r)
}
h.next.ServeHTTP(w, req)
}
func PlaintextHTTPRequest(next http.Handler) http.Handler {
return &plainTextHTTPRequestHandler{next: next}
}
@@ -0,0 +1,173 @@
package middlewares
import (
"testing"
)
var tests = []struct {
name string
forwarded string
expected string
}{
{
name: "empty header",
forwarded: "",
expected: "",
},
{
name: "single proxy with proto=https",
forwarded: "proto=https",
expected: "https",
},
{
name: "single proxy with proto=http",
forwarded: "proto=http",
expected: "http",
},
{
name: "single proxy with multiple directives",
forwarded: "for=192.0.2.60;proto=https;by=203.0.113.43",
expected: "https",
},
{
name: "single proxy with proto in middle",
forwarded: "for=192.0.2.60;proto=https;host=example.com",
expected: "https",
},
{
name: "single proxy with proto at end",
forwarded: "for=192.0.2.60;host=example.com;proto=https",
expected: "https",
},
{
name: "multiple proxies - takes first",
forwarded: "proto=https, proto=http",
expected: "https",
},
{
name: "multiple proxies with complex format",
forwarded: "for=192.0.2.43;proto=https, for=198.51.100.17;proto=http",
expected: "https",
},
{
name: "multiple proxies with for directive only",
forwarded: "for=192.0.2.43, for=198.51.100.17",
expected: "",
},
{
name: "multiple proxies with proto only in second",
forwarded: "for=192.0.2.43, proto=https",
expected: "",
},
{
name: "multiple proxies with proto only in first",
forwarded: "proto=https, for=198.51.100.17",
expected: "https",
},
{
name: "quoted protocol value",
forwarded: "proto=\"https\"",
expected: "https",
},
{
name: "single quoted protocol value",
forwarded: "proto='https'",
expected: "https",
},
{
name: "mixed case protocol",
forwarded: "proto=HTTPS",
expected: "HTTPS",
},
{
name: "no proto directive",
forwarded: "for=192.0.2.60;by=203.0.113.43",
expected: "",
},
{
name: "empty proto value",
forwarded: "proto=",
expected: "",
},
{
name: "whitespace around values",
forwarded: " proto = https ",
expected: "https",
},
{
name: "whitespace around semicolons",
forwarded: "for=192.0.2.60 ; proto=https ; by=203.0.113.43",
expected: "https",
},
{
name: "whitespace around commas",
forwarded: "proto=https , proto=http",
expected: "https",
},
{
name: "IPv6 address in for directive",
forwarded: "for=\"[2001:db8:cafe::17]:4711\";proto=https",
expected: "https",
},
{
name: "complex multiple proxies with IPv6",
forwarded: "for=192.0.2.43;proto=https, for=\"[2001:db8:cafe::17]\";proto=http",
expected: "https",
},
{
name: "obfuscated identifiers",
forwarded: "for=_mdn;proto=https",
expected: "https",
},
{
name: "unknown identifier",
forwarded: "for=unknown;proto=https",
expected: "https",
},
{
name: "malformed key-value pair",
forwarded: "proto",
expected: "",
},
{
name: "malformed key-value pair with equals",
forwarded: "proto=",
expected: "",
},
{
name: "multiple equals signs",
forwarded: "proto=https=extra",
expected: "https=extra",
},
{
name: "mixed case directive name",
forwarded: "PROTO=https",
expected: "https",
},
{
name: "mixed case directive name with spaces",
forwarded: " Proto = https ",
expected: "https",
},
}
func TestParseForwardedHeaderProto(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseForwardedHeaderProto(tt.forwarded)
if result != tt.expected {
t.Errorf("parseForwardedHeader(%q) = %q, want %q", tt.forwarded, result, tt.expected)
}
})
}
}
func FuzzParseForwardedHeaderProto(f *testing.F) {
for _, t := range tests {
f.Add(t.forwarded)
}
f.Fuzz(func(t *testing.T, forwarded string) {
parseForwardedHeaderProto(forwarded)
})
}
+27 -2
View File
@@ -7,12 +7,31 @@ import (
"strings"
)
// Note that we discard any non-canonical headers by design
var allowedHeaders = map[string]struct{}{
"Accept": {},
"Accept-Encoding": {},
"Accept-Language": {},
"Cache-Control": {},
"Content-Length": {},
"Content-Type": {},
"Private-Token": {},
"User-Agent": {},
"X-Portaineragent-Target": {},
"X-Portainer-Volumename": {},
"X-Registry-Auth": {},
}
// newSingleHostReverseProxyWithHostHeader is based on NewSingleHostReverseProxy
// from golang.org/src/net/http/httputil/reverseproxy.go and merely sets the Host
// HTTP header, which NewSingleHostReverseProxy deliberately preserves.
func newSingleHostReverseProxyWithHostHeader(target *url.URL) *httputil.ReverseProxy {
return &httputil.ReverseProxy{Director: createDirector(target)}
}
func createDirector(target *url.URL) func(*http.Request) {
targetQuery := target.RawQuery
director := func(req *http.Request) {
return func(req *http.Request) {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
@@ -26,8 +45,14 @@ func newSingleHostReverseProxyWithHostHeader(target *url.URL) *httputil.ReverseP
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
for k := range req.Header {
if _, ok := allowedHeaders[k]; !ok {
// We use delete here instead of req.Header.Del because we want to delete non canonical headers.
delete(req.Header, k)
}
}
}
return &httputil.ReverseProxy{Director: director}
}
// singleJoiningSlash from golang.org/src/net/http/httputil/reverseproxy.go
@@ -0,0 +1,190 @@
package factory
import (
"net/http"
"net/url"
"testing"
"github.com/google/go-cmp/cmp"
portainer "github.com/portainer/portainer/api"
)
func Test_createDirector(t *testing.T) {
testCases := []struct {
name string
target *url.URL
req *http.Request
expectedReq *http.Request
}{
{
name: "base case",
target: createURL(t, "https://portainer.io/api/docker?a=5&b=6"),
req: createRequest(
t,
"GET",
"https://agent-portainer.io/test?c=7",
map[string]string{"Accept-Encoding": "gzip", "Accept": "application/json", "User-Agent": "something"},
true,
),
expectedReq: createRequest(
t,
"GET",
"https://portainer.io/api/docker/test?a=5&b=6&c=7",
map[string]string{"Accept-Encoding": "gzip", "Accept": "application/json", "User-Agent": "something"},
true,
),
},
{
name: "no User-Agent",
target: createURL(t, "https://portainer.io/api/docker?a=5&b=6"),
req: createRequest(
t,
"GET",
"https://agent-portainer.io/test?c=7",
map[string]string{"Accept-Encoding": "gzip", "Accept": "application/json"},
true,
),
expectedReq: createRequest(
t,
"GET",
"https://portainer.io/api/docker/test?a=5&b=6&c=7",
map[string]string{"Accept-Encoding": "gzip", "Accept": "application/json", "User-Agent": ""},
true,
),
},
{
name: "Sensitive Headers",
target: createURL(t, "https://portainer.io/api/docker?a=5&b=6"),
req: createRequest(
t,
"GET",
"https://agent-portainer.io/test?c=7",
map[string]string{
"Authorization": "secret",
"Proxy-Authorization": "secret",
"Cookie": "secret",
"X-Csrf-Token": "secret",
"X-Api-Key": "secret",
"Accept": "application/json",
"Accept-Encoding": "gzip",
"Accept-Language": "en-GB",
"Cache-Control": "None",
"Content-Length": "100",
"Content-Type": "application/json",
"Private-Token": "test-private-token",
"User-Agent": "test-user-agent",
"X-Portaineragent-Target": "test-agent-1",
"X-Portainer-Volumename": "test-volume-1",
"X-Registry-Auth": "test-registry-auth",
},
true,
),
expectedReq: createRequest(
t,
"GET",
"https://portainer.io/api/docker/test?a=5&b=6&c=7",
map[string]string{
"Accept": "application/json",
"Accept-Encoding": "gzip",
"Accept-Language": "en-GB",
"Cache-Control": "None",
"Content-Length": "100",
"Content-Type": "application/json",
"Private-Token": "test-private-token",
"User-Agent": "test-user-agent",
"X-Portaineragent-Target": "test-agent-1",
"X-Portainer-Volumename": "test-volume-1",
"X-Registry-Auth": "test-registry-auth",
},
true,
),
},
{
name: "Non canonical Headers",
target: createURL(t, "https://portainer.io/api/docker?a=5&b=6"),
req: createRequest(
t,
"GET",
"https://agent-portainer.io/test?c=7",
map[string]string{
"Accept": "application/json",
"Accept-Encoding": "gzip",
"Accept-Language": "en-GB",
"Cache-Control": "None",
"Content-Length": "100",
"Content-Type": "application/json",
"Private-Token": "test-private-token",
"User-Agent": "test-user-agent",
portainer.PortainerAgentTargetHeader: "test-agent-1",
"X-Portainer-VolumeName": "test-volume-1",
"X-Registry-Auth": "test-registry-auth",
},
false,
),
expectedReq: createRequest(
t,
"GET",
"https://portainer.io/api/docker/test?a=5&b=6&c=7",
map[string]string{
"Accept": "application/json",
"Accept-Encoding": "gzip",
"Accept-Language": "en-GB",
"Cache-Control": "None",
"Content-Length": "100",
"Content-Type": "application/json",
"Private-Token": "test-private-token",
"User-Agent": "test-user-agent",
"X-Registry-Auth": "test-registry-auth",
},
true,
),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
director := createDirector(tc.target)
director(tc.req)
if diff := cmp.Diff(tc.req, tc.expectedReq, cmp.Comparer(compareRequests)); diff != "" {
t.Fatalf("requests are different: \n%s", diff)
}
})
}
}
func createURL(t *testing.T, urlString string) *url.URL {
parsedURL, err := url.Parse(urlString)
if err != nil {
t.Fatalf("Failed to create url: %s", err)
}
return parsedURL
}
func createRequest(t *testing.T, method, url string, headers map[string]string, canonicalHeaders bool) *http.Request {
req, err := http.NewRequest(method, url, nil)
if err != nil {
t.Fatalf("Failed to create http request: %s", err)
} else {
for k, v := range headers {
if canonicalHeaders {
req.Header.Add(k, v)
} else {
req.Header[k] = []string{v}
}
}
}
return req
}
func compareRequests(a, b *http.Request) bool {
methodEqual := a.Method == b.Method
urlEqual := cmp.Diff(a.URL, b.URL) == ""
hostEqual := a.Host == b.Host
protoEqual := a.Proto == b.Proto && a.ProtoMajor == b.ProtoMajor && a.ProtoMinor == b.ProtoMinor
headersEqual := cmp.Diff(a.Header, b.Header) == ""
return methodEqual && urlEqual && hostEqual && protoEqual && headersEqual
}
+5 -2
View File
@@ -112,6 +112,8 @@ type Server struct {
AdminCreationDone chan struct{}
PendingActionsService *pendingactions.PendingActionsService
PlatformService platform.Service
PullLimitCheckDisabled bool
TrustedOrigins []string
}
// Start starts the HTTP server
@@ -181,6 +183,7 @@ func (server *Server) Start() error {
endpointHandler.BindAddress = server.BindAddress
endpointHandler.BindAddressHTTPS = server.BindAddressHTTPS
endpointHandler.PendingActionsService = server.PendingActionsService
endpointHandler.PullLimitCheckDisabled = server.PullLimitCheckDisabled
var endpointEdgeHandler = endpointedge.NewHandler(requestBouncer, server.DataStore, server.FileService, server.ReverseTunnelService)
@@ -337,7 +340,7 @@ func (server *Server) Start() error {
handler = middlewares.WithSlowRequestsLogger(handler)
handler, err := csrf.WithProtect(handler)
handler, err := csrf.WithProtect(handler, server.TrustedOrigins)
if err != nil {
return errors.Wrap(err, "failed to create CSRF middleware")
}
@@ -347,7 +350,7 @@ func (server *Server) Start() error {
log.Info().Str("bind_address", server.BindAddress).Msg("starting HTTP server")
httpServer := &http.Server{
Addr: server.BindAddress,
Handler: handler,
Handler: middlewares.PlaintextHTTPRequest(handler),
ErrorLog: errorLogger,
}
+8
View File
@@ -87,6 +87,14 @@ func (factory *ClientFactory) ClearClientCache() {
// Remove the cached kube client so a new one can be created
func (factory *ClientFactory) RemoveKubeClient(endpointID portainer.EndpointID) {
factory.endpointProxyClients.Delete(strconv.Itoa(int(endpointID)))
endpointPrefix := strconv.Itoa(int(endpointID)) + "."
for key := range factory.endpointProxyClients.Items() {
if strings.HasPrefix(key, endpointPrefix) {
factory.endpointProxyClients.Delete(key)
}
}
}
// GetPrivilegedKubeClient checks if an existing client is already registered for the environment(endpoint) and returns it if one is found.
+12 -1
View File
@@ -134,6 +134,8 @@ type (
LogLevel *string
LogMode *string
KubectlShellImage *string
PullLimitCheckDisabled *bool
TrustedOrigins *string
}
// CustomTemplateVariableDefinition
@@ -1637,7 +1639,7 @@ type (
const (
// APIVersion is the version number of the Portainer API
APIVersion = "2.27.4"
APIVersion = "2.27.9"
// Support annotation for the API version ("STS" for Short-Term Support or "LTS" for Long-Term Support)
APIVersionSupport = "LTS"
// Edition is what this edition of Portainer is called
@@ -1689,6 +1691,15 @@ const (
PortainerCacheHeader = "X-Portainer-Cache"
// KubectlShellImageEnvVar is the environment variable used to override the default kubectl shell image
KubectlShellImageEnvVar = "KUBECTL_SHELL_IMAGE"
// PullLimitCheckDisabledEnvVar is the environment variable used to disable the pull limit check
PullLimitCheckDisabledEnvVar = "PULL_LIMIT_CHECK_DISABLED"
// LicenseServerBaseURL represents the base URL of the API used to validate
// an extension license.
LicenseServerBaseURL = "https://api.portainer.io"
// URL to validate licenses along with system metadata.
LicenseCheckInURL = LicenseServerBaseURL + "/licenses/checkin"
// TrustedOriginsEnvVar is the environment variable used to set the trusted origins for CSRF protection
TrustedOriginsEnvVar = "TRUSTED_ORIGINS"
)
// List of supported features
+3 -3
View File
@@ -1,6 +1,6 @@
module github.com/portainer/portainer
go 1.23.5
go 1.23.10
require (
github.com/Masterminds/semver v1.5.0
@@ -27,7 +27,7 @@ require (
github.com/gofrs/uuid v4.2.0+incompatible
github.com/golang-jwt/jwt/v4 v4.5.2
github.com/google/go-cmp v0.6.0
github.com/gorilla/csrf v1.7.2
github.com/gorilla/csrf v1.7.3
github.com/gorilla/mux v1.8.1
github.com/gorilla/websocket v1.5.0
github.com/hashicorp/go-version v1.7.0
@@ -242,7 +242,7 @@ require (
go.opentelemetry.io/otel/trace v1.25.0 // indirect
go.opentelemetry.io/proto/otlp v1.1.0 // indirect
go.uber.org/mock v0.5.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/term v0.30.0 // indirect
golang.org/x/text v0.23.0 // indirect
+4 -4
View File
@@ -311,8 +311,8 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI=
github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0=
github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
github.com/gorilla/mux v1.7.0/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
@@ -715,8 +715,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M=
golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+1 -1
View File
@@ -2,7 +2,7 @@
"author": "Portainer.io",
"name": "portainer",
"homepage": "http://portainer.io",
"version": "2.27.4",
"version": "2.27.9",
"repository": {
"type": "git",
"url": "git@github.com:portainer/portainer.git"
+111
View File
@@ -0,0 +1,111 @@
package validate
import (
"net"
"net/url"
"regexp"
"strings"
"unicode/utf8"
"github.com/google/uuid"
)
var (
hexadecimalRegex = regexp.MustCompile(`^[0-9a-fA-F]+$`)
dnsNameRegex = regexp.MustCompile(`^([a-zA-Z0-9_]{1}[a-zA-Z0-9_-]{0,62}){1}(\.[a-zA-Z0-9_]{1}[a-zA-Z0-9_-]{0,62})*[\._]?$`)
)
func IsURL(urlString string) bool {
if len(urlString) == 0 {
return false
}
strTemp := urlString
if !strings.Contains(urlString, "://") {
// support no indicated urlscheme
// http:// is appended so url.Parse will succeed
strTemp = "http://" + urlString
}
u, err := url.Parse(strTemp)
return err == nil && u.Host != ""
}
func IsUUID(uuidString string) bool {
return uuid.Validate(uuidString) == nil
}
func IsHexadecimal(hexString string) bool {
return hexadecimalRegex.MatchString(hexString)
}
func HasWhitespaceOnly(s string) bool {
return len(s) > 0 && strings.TrimSpace(s) == ""
}
func MinStringLength(s string, len int) bool {
return utf8.RuneCountInString(s) >= len
}
func Matches(s, pattern string) bool {
match, err := regexp.MatchString(pattern, s)
return err == nil && match
}
func IsNonPositive(f float64) bool {
return f <= 0
}
func InRange(val, left, right float64) bool {
if left > right {
left, right = right, left
}
return val >= left && val <= right
}
func IsHost(s string) bool {
return IsIP(s) || IsDNSName(s)
}
func IsIP(s string) bool {
return net.ParseIP(s) != nil
}
func IsDNSName(s string) bool {
if s == "" || len(strings.ReplaceAll(s, ".", "")) > 255 {
// constraints already violated
return false
}
return !IsIP(s) && dnsNameRegex.MatchString(s)
}
func IsTrustedOrigin(s string) bool {
// Reject if a scheme is present
if strings.Contains(s, "://") {
return false
}
// Prepend http:// for parsing
strTemp := "http://" + s
parsedOrigin, err := url.Parse(strTemp)
if err != nil {
return false
}
// Validate host, and ensure no user, path, query, fragment, port, etc.
if parsedOrigin.Host == "" ||
parsedOrigin.User != nil ||
parsedOrigin.Path != "" ||
parsedOrigin.RawQuery != "" ||
parsedOrigin.Fragment != "" ||
parsedOrigin.Opaque != "" ||
parsedOrigin.RawFragment != "" ||
parsedOrigin.RawPath != "" ||
parsedOrigin.Port() != "" {
return false
}
return true
}
+500
View File
@@ -0,0 +1,500 @@
package validate
import (
"testing"
"github.com/stretchr/testify/require"
)
func Test_IsURL(t *testing.T) {
testCases := []struct {
name string
url string
expectedResult bool
}{
{
name: "simple url",
url: "https://google.com",
expectedResult: true,
},
{
name: "empty",
url: "",
expectedResult: false,
},
{
name: "no schema",
url: "google.com",
expectedResult: true,
},
{
name: "path",
url: "https://google.com/some/thing",
expectedResult: true,
},
{
name: "query params",
url: "https://google.com/some/thing?a=5&b=6",
expectedResult: true,
},
{
name: "no top level domain",
url: "google",
expectedResult: true,
},
{
name: "Unicode URL",
url: "www.xn--exampe-7db.ai",
expectedResult: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := IsURL(tc.url)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_IsUUID(t *testing.T) {
testCases := []struct {
name string
uuid string
expectedResult bool
}{
{
name: "empty",
uuid: "",
expectedResult: false,
},
{
name: "version 3 UUID",
uuid: "060507eb-3b9a-362e-b850-d5f065eea403",
expectedResult: true,
},
{
name: "version 4 UUID",
uuid: "63e695ee-48a9-498a-98b3-9472ff75e09f",
expectedResult: true,
},
{
name: "version 5 UUID",
uuid: "5daabcd8-f17e-568c-aa6f-da9d92c7032c",
expectedResult: true,
},
{
name: "text",
uuid: "something like this",
expectedResult: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := IsUUID(tc.uuid)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_IsHexadecimal(t *testing.T) {
testCases := []struct {
name string
hex string
expectedResult bool
}{
{
name: "empty",
hex: "",
expectedResult: false,
},
{
name: "hex",
hex: "48656C6C6F20736F6D657468696E67",
expectedResult: true,
},
{
name: "text",
hex: "something like this",
expectedResult: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := IsHexadecimal(tc.hex)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_HasWhitespaceOnly(t *testing.T) {
testCases := []struct {
name string
s string
expectedResult bool
}{
{
name: "empty",
s: "",
expectedResult: false,
},
{
name: "space",
s: " ",
expectedResult: true,
},
{
name: "tab",
s: "\t",
expectedResult: true,
},
{
name: "text",
s: "something like this",
expectedResult: false,
},
{
name: "all whitespace",
s: "\t\n\v\f\r ",
expectedResult: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := HasWhitespaceOnly(tc.s)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_MinStringLength(t *testing.T) {
testCases := []struct {
name string
s string
len int
expectedResult bool
}{
{
name: "empty + zero len",
s: "",
len: 0,
expectedResult: true,
},
{
name: "empty + non zero len",
s: "",
len: 10,
expectedResult: false,
},
{
name: "long text + non zero len",
s: "something else",
len: 10,
expectedResult: true,
},
{
name: "multibyte characters - enough",
s: "X生",
len: 2,
expectedResult: true,
},
{
name: "multibyte characters - not enough",
s: "X生",
len: 3,
expectedResult: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := MinStringLength(tc.s, tc.len)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_Matches(t *testing.T) {
testCases := []struct {
name string
s string
pattern string
expectedResult bool
}{
{
name: "empty",
s: "",
pattern: "",
expectedResult: true,
},
{
name: "space",
s: "something else",
pattern: " ",
expectedResult: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := Matches(tc.s, tc.pattern)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_IsNonPositive(t *testing.T) {
testCases := []struct {
name string
f float64
expectedResult bool
}{
{
name: "zero",
f: 0,
expectedResult: true,
},
{
name: "positive",
f: 1,
expectedResult: false,
},
{
name: "negative",
f: -1,
expectedResult: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := IsNonPositive(tc.f)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_InRange(t *testing.T) {
testCases := []struct {
name string
f float64
left float64
right float64
expectedResult bool
}{
{
name: "zero",
f: 0,
left: 0,
right: 0,
expectedResult: true,
},
{
name: "equal left",
f: 1,
left: 1,
right: 2,
expectedResult: true,
},
{
name: "equal right",
f: 2,
left: 1,
right: 2,
expectedResult: true,
},
{
name: "above",
f: 3,
left: 1,
right: 2,
expectedResult: false,
},
{
name: "below",
f: 0,
left: 1,
right: 2,
expectedResult: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := InRange(tc.f, tc.left, tc.right)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_IsHost(t *testing.T) {
testCases := []struct {
name string
s string
expectedResult bool
}{
{
name: "empty",
s: "",
expectedResult: false,
},
{
name: "ip address",
s: "192.168.1.1",
expectedResult: true,
},
{
name: "hostname",
s: "google.com",
expectedResult: true,
},
{
name: "text",
s: "Something like this",
expectedResult: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := IsHost(tc.s)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_IsIP(t *testing.T) {
testCases := []struct {
name string
s string
expectedResult bool
}{
{
name: "empty",
s: "",
expectedResult: false,
},
{
name: "ip address",
s: "192.168.1.1",
expectedResult: true,
},
{
name: "hostname",
s: "google.com",
expectedResult: false,
},
{
name: "text",
s: "Something like this",
expectedResult: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := IsIP(tc.s)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_IsDNSName(t *testing.T) {
testCases := []struct {
name string
s string
expectedResult bool
}{
{
name: "empty",
s: "",
expectedResult: false,
},
{
name: "ip address",
s: "192.168.1.1",
expectedResult: false,
},
{
name: "hostname",
s: "google.com",
expectedResult: true,
},
{
name: "text",
s: "Something like this",
expectedResult: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := IsDNSName(tc.s)
require.Equal(t, tc.expectedResult, result)
})
}
}
func Test_IsTrustedOrigin(t *testing.T) {
f := func(s string, expected bool) {
t.Helper()
result := IsTrustedOrigin(s)
if result != expected {
t.Fatalf("unexpected result for %q; got %t; want %t", s, result, expected)
}
}
// Valid trusted origins - host only
f("localhost", true)
f("example.com", true)
f("192.168.1.1", true)
f("api.example.com", true)
f("subdomain.example.org", true)
// Invalid trusted origins - host with port (no longer allowed)
f("localhost:8080", false)
f("example.com:3000", false)
f("192.168.1.1:443", false)
f("api.example.com:9000", false)
// Invalid trusted origins - empty or malformed
f("", false)
f("invalid url", false)
f("://example.com", false)
// Invalid trusted origins - with scheme
f("http://example.com", false)
f("https://localhost", false)
f("ftp://192.168.1.1", false)
// Invalid trusted origins - with user info
f("user@example.com", false)
f("user:pass@localhost", false)
// Invalid trusted origins - with path
f("example.com/path", false)
f("localhost/api", false)
f("192.168.1.1/static", false)
// Invalid trusted origins - with query parameters
f("example.com?param=value", false)
f("localhost:8080?query=test", false)
// Invalid trusted origins - with fragment
f("example.com#fragment", false)
f("localhost:3000#section", false)
// Invalid trusted origins - with multiple invalid components
f("https://user@example.com/path?query=value#fragment", false)
f("http://localhost:8080/api/v1?param=test", false)
// Edge cases - ports are no longer allowed
f("example.com:0", false) // port 0 is no longer valid
f("example.com:65535", false) // max port number is no longer valid
f("example.com:99999", false) // invalid port number
f("example.com:-1", false) // negative port
}