From 554202e7d14111cbace775f8638d9344a4268b94 Mon Sep 17 00:00:00 2001 From: MHSanaei Date: Sun, 19 Apr 2026 23:55:01 +0200 Subject: [PATCH] fix --- web/controller/custom_geo.go | 2 +- web/service/custom_geo.go | 83 +++++++++++++++++++++------------- web/service/custom_geo_test.go | 3 +- 3 files changed, 54 insertions(+), 34 deletions(-) diff --git a/web/controller/custom_geo.go b/web/controller/custom_geo.go index 91c38ed7..677bda99 100644 --- a/web/controller/custom_geo.go +++ b/web/controller/custom_geo.go @@ -67,7 +67,7 @@ func mapCustomGeoErr(c *gin.Context, err error) error { return errors.New(I18nWeb(c, "pages.index.customGeoErrUrlHost")) case errors.Is(err, service.ErrCustomGeoPathTraversal): logger.Warning("custom geo path traversal blocked:", err) - return errors.New(I18nWeb(c, "pages.index.customGeoErrAliasPattern")) + return errors.New(I18nWeb(c, "pages.index.customGeoErrDownload")) default: return err } diff --git a/web/service/custom_geo.go b/web/service/custom_geo.go index d60ee093..e69d9f1d 100644 --- a/web/service/custom_geo.go +++ b/web/service/custom_geo.go @@ -129,7 +129,7 @@ func (s *CustomGeoService) sanitizeURL(raw string) (string, error) { if u.Host == "" { return "", ErrCustomGeoURLHost } - if err := checkSSRF(u.Hostname()); err != nil { + if err := checkSSRF(context.Background(), u.Hostname()); err != nil { return "", err } // Reconstruct URL from parsed components to break taint propagation. @@ -165,14 +165,15 @@ func isBlockedIP(ip net.IP) bool { } // checkSSRFDefault validates that the given host does not resolve to a private/internal IP. -func checkSSRFDefault(hostname string) error { - ips, err := net.LookupIP(hostname) +// It is context-aware so that dial context cancellation/deadlines are respected during DNS resolution. +func checkSSRFDefault(ctx context.Context, hostname string) error { + ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname) if err != nil { return fmt.Errorf("%w: cannot resolve host %s", ErrCustomGeoSSRFBlocked, hostname) } - for _, ip := range ips { - if isBlockedIP(ip) { - return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ip) + for _, ipAddr := range ips { + if isBlockedIP(ipAddr.IP) { + return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ipAddr.IP) } } return nil @@ -181,20 +182,24 @@ func checkSSRFDefault(hostname string) error { // checkSSRF is the active SSRF guard. Override in tests to allow localhost test servers. var checkSSRF = checkSSRFDefault -func ssrfSafeTransport() *http.Transport { - return &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - host, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err) - } - if err := checkSSRF(host); err != nil { - return nil, err - } - var dialer net.Dialer - return dialer.DialContext(ctx, network, addr) - }, +func ssrfSafeTransport() http.RoundTripper { + base, ok := http.DefaultTransport.(*http.Transport) + if !ok { + base = &http.Transport{} } + cloned := base.Clone() + cloned.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err) + } + if err := checkSSRF(ctx, host); err != nil { + return nil, err + } + var dialer net.Dialer + return dialer.DialContext(ctx, network, addr) + } + return cloned } func probeCustomGeoURLWithGET(rawURL string) error { @@ -287,25 +292,38 @@ func (s *CustomGeoService) downloadToPath(resourceURL, destPath string, lastModi return false, lm, nil } -// validateDestPath ensures destPath is inside the bin folder, preventing path traversal. -func validateDestPath(destPath string) error { +// sanitizeDestPath ensures destPath is inside the bin folder, preventing path traversal. +// It resolves symlinks to prevent symlink-based escapes. +// Returns the cleaned absolute path that is safe to use in file operations. +func sanitizeDestPath(destPath string) (string, error) { baseDirAbs, err := filepath.Abs(config.GetBinFolderPath()) if err != nil { - return fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err) + return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err) + } + // Resolve symlinks in base directory to get the real path. + if resolved, evalErr := filepath.EvalSymlinks(baseDirAbs); evalErr == nil { + baseDirAbs = resolved } destPathAbs, err := filepath.Abs(destPath) if err != nil { - return fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err) + return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err) } - relToBase, err := filepath.Rel(baseDirAbs, destPathAbs) - if err != nil || strings.HasPrefix(relToBase, "..") || filepath.IsAbs(relToBase) { - return ErrCustomGeoPathTraversal + // Resolve symlinks for the parent directory of the destination path. + destDir := filepath.Dir(destPathAbs) + if resolved, evalErr := filepath.EvalSymlinks(destDir); evalErr == nil { + destPathAbs = filepath.Join(resolved, filepath.Base(destPathAbs)) } - return nil + // Verify the resolved path is within the safe base directory using prefix check. + safeDirPrefix := baseDirAbs + string(filepath.Separator) + if !strings.HasPrefix(destPathAbs, safeDirPrefix) { + return "", ErrCustomGeoPathTraversal + } + return destPathAbs, nil } func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, lastModifiedHeader string, forceFull bool) (skipped bool, newLastModified string, err error) { - if err := validateDestPath(destPath); err != nil { + destPath, err = sanitizeDestPath(destPath) + if err != nil { return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err) } @@ -477,7 +495,8 @@ func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error { s.syncLocalPath(r) r.Id = id r.LocalPath = filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias)) - if err := validateDestPath(r.LocalPath); err != nil { + r.LocalPath, err = sanitizeDestPath(r.LocalPath) + if err != nil { return err } if oldPath != r.LocalPath && oldPath != "" { @@ -518,9 +537,9 @@ func (s *CustomGeoService) Delete(id int) (displayName string, err error) { return "", err } displayName = s.fileNameFor(r.Type, r.Alias) - p := s.resolveDestPath(&r) - if err := validateDestPath(p); err != nil { - return displayName, err + p, pathErr := sanitizeDestPath(s.resolveDestPath(&r)) + if pathErr != nil { + return displayName, pathErr } if err := database.GetDB().Delete(&model.CustomGeoResource{}, id).Error; err != nil { return displayName, err diff --git a/web/service/custom_geo_test.go b/web/service/custom_geo_test.go index 1c60c54a..731c4fc6 100644 --- a/web/service/custom_geo_test.go +++ b/web/service/custom_geo_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "errors" "fmt" "net/http" @@ -17,7 +18,7 @@ import ( func disableSSRFCheck(t *testing.T) { t.Helper() orig := checkSSRF - checkSSRF = func(string) error { return nil } + checkSSRF = func(_ context.Context, _ string) error { return nil } t.Cleanup(func() { checkSSRF = orig }) }