This commit is contained in:
MHSanaei 2026-04-19 23:55:01 +02:00
parent f9fe4b2e5e
commit 554202e7d1
No known key found for this signature in database
GPG key ID: 7E4060F2FBE5AB7A
3 changed files with 54 additions and 34 deletions

View file

@ -67,7 +67,7 @@ func mapCustomGeoErr(c *gin.Context, err error) error {
return errors.New(I18nWeb(c, "pages.index.customGeoErrUrlHost")) return errors.New(I18nWeb(c, "pages.index.customGeoErrUrlHost"))
case errors.Is(err, service.ErrCustomGeoPathTraversal): case errors.Is(err, service.ErrCustomGeoPathTraversal):
logger.Warning("custom geo path traversal blocked:", err) logger.Warning("custom geo path traversal blocked:", err)
return errors.New(I18nWeb(c, "pages.index.customGeoErrAliasPattern")) return errors.New(I18nWeb(c, "pages.index.customGeoErrDownload"))
default: default:
return err return err
} }

View file

@ -129,7 +129,7 @@ func (s *CustomGeoService) sanitizeURL(raw string) (string, error) {
if u.Host == "" { if u.Host == "" {
return "", ErrCustomGeoURLHost return "", ErrCustomGeoURLHost
} }
if err := checkSSRF(u.Hostname()); err != nil { if err := checkSSRF(context.Background(), u.Hostname()); err != nil {
return "", err return "", err
} }
// Reconstruct URL from parsed components to break taint propagation. // Reconstruct URL from parsed components to break taint propagation.
@ -165,14 +165,15 @@ func isBlockedIP(ip net.IP) bool {
} }
// checkSSRFDefault validates that the given host does not resolve to a private/internal IP. // checkSSRFDefault validates that the given host does not resolve to a private/internal IP.
func checkSSRFDefault(hostname string) error { // It is context-aware so that dial context cancellation/deadlines are respected during DNS resolution.
ips, err := net.LookupIP(hostname) func checkSSRFDefault(ctx context.Context, hostname string) error {
ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
if err != nil { if err != nil {
return fmt.Errorf("%w: cannot resolve host %s", ErrCustomGeoSSRFBlocked, hostname) return fmt.Errorf("%w: cannot resolve host %s", ErrCustomGeoSSRFBlocked, hostname)
} }
for _, ip := range ips { for _, ipAddr := range ips {
if isBlockedIP(ip) { if isBlockedIP(ipAddr.IP) {
return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ip) return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ipAddr.IP)
} }
} }
return nil return nil
@ -181,20 +182,24 @@ func checkSSRFDefault(hostname string) error {
// checkSSRF is the active SSRF guard. Override in tests to allow localhost test servers. // checkSSRF is the active SSRF guard. Override in tests to allow localhost test servers.
var checkSSRF = checkSSRFDefault var checkSSRF = checkSSRFDefault
func ssrfSafeTransport() *http.Transport { func ssrfSafeTransport() http.RoundTripper {
return &http.Transport{ base, ok := http.DefaultTransport.(*http.Transport)
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if !ok {
host, _, err := net.SplitHostPort(addr) base = &http.Transport{}
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err)
}
if err := checkSSRF(host); err != nil {
return nil, err
}
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
},
} }
cloned := base.Clone()
cloned.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err)
}
if err := checkSSRF(ctx, host); err != nil {
return nil, err
}
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
}
return cloned
} }
func probeCustomGeoURLWithGET(rawURL string) error { func probeCustomGeoURLWithGET(rawURL string) error {
@ -287,25 +292,38 @@ func (s *CustomGeoService) downloadToPath(resourceURL, destPath string, lastModi
return false, lm, nil return false, lm, nil
} }
// validateDestPath ensures destPath is inside the bin folder, preventing path traversal. // sanitizeDestPath ensures destPath is inside the bin folder, preventing path traversal.
func validateDestPath(destPath string) error { // It resolves symlinks to prevent symlink-based escapes.
// Returns the cleaned absolute path that is safe to use in file operations.
func sanitizeDestPath(destPath string) (string, error) {
baseDirAbs, err := filepath.Abs(config.GetBinFolderPath()) baseDirAbs, err := filepath.Abs(config.GetBinFolderPath())
if err != nil { if err != nil {
return fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err) return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
}
// Resolve symlinks in base directory to get the real path.
if resolved, evalErr := filepath.EvalSymlinks(baseDirAbs); evalErr == nil {
baseDirAbs = resolved
} }
destPathAbs, err := filepath.Abs(destPath) destPathAbs, err := filepath.Abs(destPath)
if err != nil { if err != nil {
return fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err) return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
} }
relToBase, err := filepath.Rel(baseDirAbs, destPathAbs) // Resolve symlinks for the parent directory of the destination path.
if err != nil || strings.HasPrefix(relToBase, "..") || filepath.IsAbs(relToBase) { destDir := filepath.Dir(destPathAbs)
return ErrCustomGeoPathTraversal if resolved, evalErr := filepath.EvalSymlinks(destDir); evalErr == nil {
destPathAbs = filepath.Join(resolved, filepath.Base(destPathAbs))
} }
return nil // Verify the resolved path is within the safe base directory using prefix check.
safeDirPrefix := baseDirAbs + string(filepath.Separator)
if !strings.HasPrefix(destPathAbs, safeDirPrefix) {
return "", ErrCustomGeoPathTraversal
}
return destPathAbs, nil
} }
func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, lastModifiedHeader string, forceFull bool) (skipped bool, newLastModified string, err error) { func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, lastModifiedHeader string, forceFull bool) (skipped bool, newLastModified string, err error) {
if err := validateDestPath(destPath); err != nil { destPath, err = sanitizeDestPath(destPath)
if err != nil {
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err) return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
} }
@ -477,7 +495,8 @@ func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error {
s.syncLocalPath(r) s.syncLocalPath(r)
r.Id = id r.Id = id
r.LocalPath = filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias)) r.LocalPath = filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias))
if err := validateDestPath(r.LocalPath); err != nil { r.LocalPath, err = sanitizeDestPath(r.LocalPath)
if err != nil {
return err return err
} }
if oldPath != r.LocalPath && oldPath != "" { if oldPath != r.LocalPath && oldPath != "" {
@ -518,9 +537,9 @@ func (s *CustomGeoService) Delete(id int) (displayName string, err error) {
return "", err return "", err
} }
displayName = s.fileNameFor(r.Type, r.Alias) displayName = s.fileNameFor(r.Type, r.Alias)
p := s.resolveDestPath(&r) p, pathErr := sanitizeDestPath(s.resolveDestPath(&r))
if err := validateDestPath(p); err != nil { if pathErr != nil {
return displayName, err return displayName, pathErr
} }
if err := database.GetDB().Delete(&model.CustomGeoResource{}, id).Error; err != nil { if err := database.GetDB().Delete(&model.CustomGeoResource{}, id).Error; err != nil {
return displayName, err return displayName, err

View file

@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -17,7 +18,7 @@ import (
func disableSSRFCheck(t *testing.T) { func disableSSRFCheck(t *testing.T) {
t.Helper() t.Helper()
orig := checkSSRF orig := checkSSRF
checkSSRF = func(string) error { return nil } checkSSRF = func(_ context.Context, _ string) error { return nil }
t.Cleanup(func() { checkSSRF = orig }) t.Cleanup(func() { checkSSRF = orig })
} }