This commit is contained in:
MHSanaei 2026-04-19 23:55:01 +02:00
parent f9fe4b2e5e
commit 554202e7d1
No known key found for this signature in database
GPG key ID: 7E4060F2FBE5AB7A
3 changed files with 54 additions and 34 deletions

View file

@ -67,7 +67,7 @@ func mapCustomGeoErr(c *gin.Context, err error) error {
return errors.New(I18nWeb(c, "pages.index.customGeoErrUrlHost"))
case errors.Is(err, service.ErrCustomGeoPathTraversal):
logger.Warning("custom geo path traversal blocked:", err)
return errors.New(I18nWeb(c, "pages.index.customGeoErrAliasPattern"))
return errors.New(I18nWeb(c, "pages.index.customGeoErrDownload"))
default:
return err
}

View file

@ -129,7 +129,7 @@ func (s *CustomGeoService) sanitizeURL(raw string) (string, error) {
if u.Host == "" {
return "", ErrCustomGeoURLHost
}
if err := checkSSRF(u.Hostname()); err != nil {
if err := checkSSRF(context.Background(), u.Hostname()); err != nil {
return "", err
}
// Reconstruct URL from parsed components to break taint propagation.
@ -165,14 +165,15 @@ func isBlockedIP(ip net.IP) bool {
}
// checkSSRFDefault validates that the given host does not resolve to a private/internal IP.
func checkSSRFDefault(hostname string) error {
ips, err := net.LookupIP(hostname)
// It is context-aware so that dial context cancellation/deadlines are respected during DNS resolution.
func checkSSRFDefault(ctx context.Context, hostname string) error {
ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
if err != nil {
return fmt.Errorf("%w: cannot resolve host %s", ErrCustomGeoSSRFBlocked, hostname)
}
for _, ip := range ips {
if isBlockedIP(ip) {
return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ip)
for _, ipAddr := range ips {
if isBlockedIP(ipAddr.IP) {
return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ipAddr.IP)
}
}
return nil
@ -181,20 +182,24 @@ func checkSSRFDefault(hostname string) error {
// checkSSRF is the active SSRF guard. Override in tests to allow localhost test servers.
var checkSSRF = checkSSRFDefault
func ssrfSafeTransport() *http.Transport {
return &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err)
}
if err := checkSSRF(host); err != nil {
return nil, err
}
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
},
func ssrfSafeTransport() http.RoundTripper {
base, ok := http.DefaultTransport.(*http.Transport)
if !ok {
base = &http.Transport{}
}
cloned := base.Clone()
cloned.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err)
}
if err := checkSSRF(ctx, host); err != nil {
return nil, err
}
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
}
return cloned
}
func probeCustomGeoURLWithGET(rawURL string) error {
@ -287,25 +292,38 @@ func (s *CustomGeoService) downloadToPath(resourceURL, destPath string, lastModi
return false, lm, nil
}
// validateDestPath ensures destPath is inside the bin folder, preventing path traversal.
func validateDestPath(destPath string) error {
// sanitizeDestPath ensures destPath is inside the bin folder, preventing path traversal.
// It resolves symlinks to prevent symlink-based escapes.
// Returns the cleaned absolute path that is safe to use in file operations.
func sanitizeDestPath(destPath string) (string, error) {
baseDirAbs, err := filepath.Abs(config.GetBinFolderPath())
if err != nil {
return fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
}
// Resolve symlinks in base directory to get the real path.
if resolved, evalErr := filepath.EvalSymlinks(baseDirAbs); evalErr == nil {
baseDirAbs = resolved
}
destPathAbs, err := filepath.Abs(destPath)
if err != nil {
return fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
}
relToBase, err := filepath.Rel(baseDirAbs, destPathAbs)
if err != nil || strings.HasPrefix(relToBase, "..") || filepath.IsAbs(relToBase) {
return ErrCustomGeoPathTraversal
// Resolve symlinks for the parent directory of the destination path.
destDir := filepath.Dir(destPathAbs)
if resolved, evalErr := filepath.EvalSymlinks(destDir); evalErr == nil {
destPathAbs = filepath.Join(resolved, filepath.Base(destPathAbs))
}
return nil
// Verify the resolved path is within the safe base directory using prefix check.
safeDirPrefix := baseDirAbs + string(filepath.Separator)
if !strings.HasPrefix(destPathAbs, safeDirPrefix) {
return "", ErrCustomGeoPathTraversal
}
return destPathAbs, nil
}
func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, lastModifiedHeader string, forceFull bool) (skipped bool, newLastModified string, err error) {
if err := validateDestPath(destPath); err != nil {
destPath, err = sanitizeDestPath(destPath)
if err != nil {
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
}
@ -477,7 +495,8 @@ func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error {
s.syncLocalPath(r)
r.Id = id
r.LocalPath = filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias))
if err := validateDestPath(r.LocalPath); err != nil {
r.LocalPath, err = sanitizeDestPath(r.LocalPath)
if err != nil {
return err
}
if oldPath != r.LocalPath && oldPath != "" {
@ -518,9 +537,9 @@ func (s *CustomGeoService) Delete(id int) (displayName string, err error) {
return "", err
}
displayName = s.fileNameFor(r.Type, r.Alias)
p := s.resolveDestPath(&r)
if err := validateDestPath(p); err != nil {
return displayName, err
p, pathErr := sanitizeDestPath(s.resolveDestPath(&r))
if pathErr != nil {
return displayName, pathErr
}
if err := database.GetDB().Delete(&model.CustomGeoResource{}, id).Error; err != nil {
return displayName, err

View file

@ -1,6 +1,7 @@
package service
import (
"context"
"errors"
"fmt"
"net/http"
@ -17,7 +18,7 @@ import (
func disableSSRFCheck(t *testing.T) {
t.Helper()
orig := checkSSRF
checkSSRF = func(string) error { return nil }
checkSSRF = func(_ context.Context, _ string) error { return nil }
t.Cleanup(func() { checkSSRF = orig })
}