mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-06 21:24:10 +00:00
fix
This commit is contained in:
parent
f9fe4b2e5e
commit
554202e7d1
3 changed files with 54 additions and 34 deletions
|
|
@ -67,7 +67,7 @@ func mapCustomGeoErr(c *gin.Context, err error) error {
|
|||
return errors.New(I18nWeb(c, "pages.index.customGeoErrUrlHost"))
|
||||
case errors.Is(err, service.ErrCustomGeoPathTraversal):
|
||||
logger.Warning("custom geo path traversal blocked:", err)
|
||||
return errors.New(I18nWeb(c, "pages.index.customGeoErrAliasPattern"))
|
||||
return errors.New(I18nWeb(c, "pages.index.customGeoErrDownload"))
|
||||
default:
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ func (s *CustomGeoService) sanitizeURL(raw string) (string, error) {
|
|||
if u.Host == "" {
|
||||
return "", ErrCustomGeoURLHost
|
||||
}
|
||||
if err := checkSSRF(u.Hostname()); err != nil {
|
||||
if err := checkSSRF(context.Background(), u.Hostname()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Reconstruct URL from parsed components to break taint propagation.
|
||||
|
|
@ -165,14 +165,15 @@ func isBlockedIP(ip net.IP) bool {
|
|||
}
|
||||
|
||||
// checkSSRFDefault validates that the given host does not resolve to a private/internal IP.
|
||||
func checkSSRFDefault(hostname string) error {
|
||||
ips, err := net.LookupIP(hostname)
|
||||
// It is context-aware so that dial context cancellation/deadlines are respected during DNS resolution.
|
||||
func checkSSRFDefault(ctx context.Context, hostname string) error {
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: cannot resolve host %s", ErrCustomGeoSSRFBlocked, hostname)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isBlockedIP(ip) {
|
||||
return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ip)
|
||||
for _, ipAddr := range ips {
|
||||
if isBlockedIP(ipAddr.IP) {
|
||||
return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ipAddr.IP)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
@ -181,20 +182,24 @@ func checkSSRFDefault(hostname string) error {
|
|||
// checkSSRF is the active SSRF guard. Override in tests to allow localhost test servers.
|
||||
var checkSSRF = checkSSRFDefault
|
||||
|
||||
func ssrfSafeTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err)
|
||||
}
|
||||
if err := checkSSRF(host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dialer net.Dialer
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
func ssrfSafeTransport() http.RoundTripper {
|
||||
base, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
base = &http.Transport{}
|
||||
}
|
||||
cloned := base.Clone()
|
||||
cloned.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err)
|
||||
}
|
||||
if err := checkSSRF(ctx, host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dialer net.Dialer
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func probeCustomGeoURLWithGET(rawURL string) error {
|
||||
|
|
@ -287,25 +292,38 @@ func (s *CustomGeoService) downloadToPath(resourceURL, destPath string, lastModi
|
|||
return false, lm, nil
|
||||
}
|
||||
|
||||
// validateDestPath ensures destPath is inside the bin folder, preventing path traversal.
|
||||
func validateDestPath(destPath string) error {
|
||||
// sanitizeDestPath ensures destPath is inside the bin folder, preventing path traversal.
|
||||
// It resolves symlinks to prevent symlink-based escapes.
|
||||
// Returns the cleaned absolute path that is safe to use in file operations.
|
||||
func sanitizeDestPath(destPath string) (string, error) {
|
||||
baseDirAbs, err := filepath.Abs(config.GetBinFolderPath())
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
|
||||
return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
|
||||
}
|
||||
// Resolve symlinks in base directory to get the real path.
|
||||
if resolved, evalErr := filepath.EvalSymlinks(baseDirAbs); evalErr == nil {
|
||||
baseDirAbs = resolved
|
||||
}
|
||||
destPathAbs, err := filepath.Abs(destPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
|
||||
return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
|
||||
}
|
||||
relToBase, err := filepath.Rel(baseDirAbs, destPathAbs)
|
||||
if err != nil || strings.HasPrefix(relToBase, "..") || filepath.IsAbs(relToBase) {
|
||||
return ErrCustomGeoPathTraversal
|
||||
// Resolve symlinks for the parent directory of the destination path.
|
||||
destDir := filepath.Dir(destPathAbs)
|
||||
if resolved, evalErr := filepath.EvalSymlinks(destDir); evalErr == nil {
|
||||
destPathAbs = filepath.Join(resolved, filepath.Base(destPathAbs))
|
||||
}
|
||||
return nil
|
||||
// Verify the resolved path is within the safe base directory using prefix check.
|
||||
safeDirPrefix := baseDirAbs + string(filepath.Separator)
|
||||
if !strings.HasPrefix(destPathAbs, safeDirPrefix) {
|
||||
return "", ErrCustomGeoPathTraversal
|
||||
}
|
||||
return destPathAbs, nil
|
||||
}
|
||||
|
||||
func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, lastModifiedHeader string, forceFull bool) (skipped bool, newLastModified string, err error) {
|
||||
if err := validateDestPath(destPath); err != nil {
|
||||
destPath, err = sanitizeDestPath(destPath)
|
||||
if err != nil {
|
||||
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
|
||||
}
|
||||
|
||||
|
|
@ -477,7 +495,8 @@ func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error {
|
|||
s.syncLocalPath(r)
|
||||
r.Id = id
|
||||
r.LocalPath = filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias))
|
||||
if err := validateDestPath(r.LocalPath); err != nil {
|
||||
r.LocalPath, err = sanitizeDestPath(r.LocalPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if oldPath != r.LocalPath && oldPath != "" {
|
||||
|
|
@ -518,9 +537,9 @@ func (s *CustomGeoService) Delete(id int) (displayName string, err error) {
|
|||
return "", err
|
||||
}
|
||||
displayName = s.fileNameFor(r.Type, r.Alias)
|
||||
p := s.resolveDestPath(&r)
|
||||
if err := validateDestPath(p); err != nil {
|
||||
return displayName, err
|
||||
p, pathErr := sanitizeDestPath(s.resolveDestPath(&r))
|
||||
if pathErr != nil {
|
||||
return displayName, pathErr
|
||||
}
|
||||
if err := database.GetDB().Delete(&model.CustomGeoResource{}, id).Error; err != nil {
|
||||
return displayName, err
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
|
@ -17,7 +18,7 @@ import (
|
|||
func disableSSRFCheck(t *testing.T) {
|
||||
t.Helper()
|
||||
orig := checkSSRF
|
||||
checkSSRF = func(string) error { return nil }
|
||||
checkSSRF = func(_ context.Context, _ string) error { return nil }
|
||||
t.Cleanup(func() { checkSSRF = orig })
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue