Add SSRF protection for custom geo downloads

Introduce SSRF-safe HTTP transport for custom geo operations by adding ssrfSafeTransport and isBlockedIP helpers. The transport resolves hosts and blocks loopback, private, link-local and unspecified addresses, returning ErrCustomGeoSSRFBlocked on violations. Update probeCustomGeoURLWithGET, probeCustomGeoURL and downloadToPathOnce to use the safe transport. Also add the new error ErrCustomGeoSSRFBlocked and necessary imports. Minor whitespace/formatting adjustments in subClashService.go, web/entity/entity.go and web/service/setting.go.
This commit is contained in:
MHSanaei 2026-04-19 23:10:17 +02:00
parent d580086361
commit 91ee295199
No known key found for this signature in database
GPG key ID: 7E4060F2FBE5AB7A
6 changed files with 104 additions and 31 deletions

View file

@ -160,10 +160,10 @@ func (s *SubClashService) getProxies(inbound *model.Inbound, client model.Client
func (s *SubClashService) buildProxy(inbound *model.Inbound, client model.Client, stream map[string]any, extraRemark string) map[string]any { func (s *SubClashService) buildProxy(inbound *model.Inbound, client model.Client, stream map[string]any, extraRemark string) map[string]any {
proxy := map[string]any{ proxy := map[string]any{
"name": s.SubService.genRemark(inbound, client.Email, extraRemark), "name": s.SubService.genRemark(inbound, client.Email, extraRemark),
"server": inbound.Listen, "server": inbound.Listen,
"port": inbound.Port, "port": inbound.Port,
"udp": true, "udp": true,
} }
network, _ := stream["network"].(string) network, _ := stream["network"].(string)

View file

@ -62,6 +62,9 @@ func mapCustomGeoErr(c *gin.Context, err error) error {
case errors.Is(err, service.ErrCustomGeoDownload): case errors.Is(err, service.ErrCustomGeoDownload):
logger.Warning("custom geo download:", err) logger.Warning("custom geo download:", err)
return errors.New(I18nWeb(c, "pages.index.customGeoErrDownload")) return errors.New(I18nWeb(c, "pages.index.customGeoErrDownload"))
case errors.Is(err, service.ErrCustomGeoSSRFBlocked):
logger.Warning("custom geo SSRF blocked:", err)
return errors.New(I18nWeb(c, "pages.index.customGeoErrUrlHost"))
default: default:
return err return err
} }

View file

@ -76,9 +76,9 @@ type AllSetting struct {
SubURI string `json:"subURI" form:"subURI"` // Subscription server URI SubURI string `json:"subURI" form:"subURI"` // Subscription server URI
SubJsonPath string `json:"subJsonPath" form:"subJsonPath"` // Path for JSON subscription endpoint SubJsonPath string `json:"subJsonPath" form:"subJsonPath"` // Path for JSON subscription endpoint
SubJsonURI string `json:"subJsonURI" form:"subJsonURI"` // JSON subscription server URI SubJsonURI string `json:"subJsonURI" form:"subJsonURI"` // JSON subscription server URI
SubClashEnable bool `json:"subClashEnable" form:"subClashEnable"` // Enable Clash/Mihomo subscription endpoint SubClashEnable bool `json:"subClashEnable" form:"subClashEnable"` // Enable Clash/Mihomo subscription endpoint
SubClashPath string `json:"subClashPath" form:"subClashPath"` // Path for Clash/Mihomo subscription endpoint SubClashPath string `json:"subClashPath" form:"subClashPath"` // Path for Clash/Mihomo subscription endpoint
SubClashURI string `json:"subClashURI" form:"subClashURI"` // Clash/Mihomo subscription server URI SubClashURI string `json:"subClashURI" form:"subClashURI"` // Clash/Mihomo subscription server URI
SubJsonFragment string `json:"subJsonFragment" form:"subJsonFragment"` // JSON subscription fragment configuration SubJsonFragment string `json:"subJsonFragment" form:"subJsonFragment"` // JSON subscription fragment configuration
SubJsonNoises string `json:"subJsonNoises" form:"subJsonNoises"` // JSON subscription noise configuration SubJsonNoises string `json:"subJsonNoises" form:"subJsonNoises"` // JSON subscription noise configuration
SubJsonMux string `json:"subJsonMux" form:"subJsonMux"` // JSON subscription mux configuration SubJsonMux string `json:"subJsonMux" form:"subJsonMux"` // JSON subscription mux configuration

View file

@ -1,9 +1,11 @@
package service package service
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -43,6 +45,7 @@ var (
ErrCustomGeoDuplicateAlias = errors.New("custom_geo_duplicate_alias") ErrCustomGeoDuplicateAlias = errors.New("custom_geo_duplicate_alias")
ErrCustomGeoNotFound = errors.New("custom_geo_not_found") ErrCustomGeoNotFound = errors.New("custom_geo_not_found")
ErrCustomGeoDownload = errors.New("custom_geo_download") ErrCustomGeoDownload = errors.New("custom_geo_download")
ErrCustomGeoSSRFBlocked = errors.New("custom_geo_ssrf_blocked")
) )
type CustomGeoUpdateAllItem struct { type CustomGeoUpdateAllItem struct {
@ -111,21 +114,24 @@ func (s *CustomGeoService) validateAlias(alias string) error {
return nil return nil
} }
func (s *CustomGeoService) validateURL(raw string) error { func (s *CustomGeoService) sanitizeURL(raw string) (string, error) {
if raw == "" { if raw == "" {
return ErrCustomGeoURLRequired return "", ErrCustomGeoURLRequired
} }
u, err := url.Parse(raw) u, err := url.Parse(raw)
if err != nil { if err != nil {
return ErrCustomGeoInvalidURL return "", ErrCustomGeoInvalidURL
} }
if u.Scheme != "http" && u.Scheme != "https" { if u.Scheme != "http" && u.Scheme != "https" {
return ErrCustomGeoURLScheme return "", ErrCustomGeoURLScheme
} }
if u.Host == "" { if u.Host == "" {
return ErrCustomGeoURLHost return "", ErrCustomGeoURLHost
} }
return nil if err := checkSSRF(u.Hostname()); err != nil {
return "", err
}
return u.String(), nil
} }
func localDatFileNeedsRepair(path string) bool { func localDatFileNeedsRepair(path string) bool {
@ -143,8 +149,46 @@ func CustomGeoLocalFileNeedsRepair(path string) bool {
return localDatFileNeedsRepair(path) return localDatFileNeedsRepair(path)
} }
func isBlockedIP(ip net.IP) bool {
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsUnspecified()
}
// checkSSRFDefault validates that the given host does not resolve to a private/internal IP.
func checkSSRFDefault(hostname string) error {
ips, err := net.LookupIP(hostname)
if err != nil {
return fmt.Errorf("%w: cannot resolve host %s", ErrCustomGeoSSRFBlocked, hostname)
}
for _, ip := range ips {
if isBlockedIP(ip) {
return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ip)
}
}
return nil
}
// checkSSRF is the active SSRF guard. Override in tests to allow localhost test servers.
var checkSSRF = checkSSRFDefault
func ssrfSafeTransport() *http.Transport {
return &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err)
}
if err := checkSSRF(host); err != nil {
return nil, err
}
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
},
}
}
func probeCustomGeoURLWithGET(rawURL string) error { func probeCustomGeoURLWithGET(rawURL string) error {
client := &http.Client{Timeout: customGeoProbeTimeout} client := &http.Client{Timeout: customGeoProbeTimeout, Transport: ssrfSafeTransport()}
req, err := http.NewRequest(http.MethodGet, rawURL, nil) req, err := http.NewRequest(http.MethodGet, rawURL, nil)
if err != nil { if err != nil {
return err return err
@ -165,7 +209,7 @@ func probeCustomGeoURLWithGET(rawURL string) error {
} }
func probeCustomGeoURL(rawURL string) error { func probeCustomGeoURL(rawURL string) error {
client := &http.Client{Timeout: customGeoProbeTimeout} client := &http.Client{Timeout: customGeoProbeTimeout, Transport: ssrfSafeTransport()}
req, err := http.NewRequest(http.MethodHead, rawURL, nil) req, err := http.NewRequest(http.MethodHead, rawURL, nil)
if err != nil { if err != nil {
return err return err
@ -199,10 +243,12 @@ func (s *CustomGeoService) EnsureOnStartup() {
logger.Infof("custom geo startup: checking %d custom geofile(s)", n) logger.Infof("custom geo startup: checking %d custom geofile(s)", n)
for i := range list { for i := range list {
r := &list[i] r := &list[i]
if err := s.validateURL(r.Url); err != nil { sanitizedURL, err := s.sanitizeURL(r.Url)
if err != nil {
logger.Warningf("custom geo startup id=%d: invalid url: %v", r.Id, err) logger.Warningf("custom geo startup id=%d: invalid url: %v", r.Id, err)
continue continue
} }
r.Url = sanitizedURL
s.syncLocalPath(r) s.syncLocalPath(r)
localPath := r.LocalPath localPath := r.LocalPath
if !localDatFileNeedsRepair(localPath) { if !localDatFileNeedsRepair(localPath) {
@ -250,7 +296,7 @@ func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, last
} }
} }
client := &http.Client{Timeout: 10 * time.Minute} client := &http.Client{Timeout: 10 * time.Minute, Transport: ssrfSafeTransport()}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err) return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
@ -338,9 +384,11 @@ func (s *CustomGeoService) Create(r *model.CustomGeoResource) error {
if err := s.validateAlias(r.Alias); err != nil { if err := s.validateAlias(r.Alias); err != nil {
return err return err
} }
if err := s.validateURL(r.Url); err != nil { sanitizedURL, err := s.sanitizeURL(r.Url)
if err != nil {
return err return err
} }
r.Url = sanitizedURL
var existing int64 var existing int64
database.GetDB().Model(&model.CustomGeoResource{}). database.GetDB().Model(&model.CustomGeoResource{}).
Where("geo_type = ? AND alias = ?", r.Type, r.Alias).Count(&existing) Where("geo_type = ? AND alias = ?", r.Type, r.Alias).Count(&existing)
@ -380,9 +428,11 @@ func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error {
if err := s.validateAlias(r.Alias); err != nil { if err := s.validateAlias(r.Alias); err != nil {
return err return err
} }
if err := s.validateURL(r.Url); err != nil { sanitizedURL, err := s.sanitizeURL(r.Url)
if err != nil {
return err return err
} }
r.Url = sanitizedURL
if cur.Type != r.Type || cur.Alias != r.Alias { if cur.Type != r.Type || cur.Alias != r.Alias {
var cnt int64 var cnt int64
database.GetDB().Model(&model.CustomGeoResource{}). database.GetDB().Model(&model.CustomGeoResource{}).
@ -468,7 +518,11 @@ func (s *CustomGeoService) applyDownloadAndPersist(id int, onStartup bool) (disp
} }
displayName = s.fileNameFor(r.Type, r.Alias) displayName = s.fileNameFor(r.Type, r.Alias)
s.syncLocalPath(&r) s.syncLocalPath(&r)
skipped, lm, err := s.downloadToPath(r.Url, r.LocalPath, r.LastModified) sanitizedURL, sanitizeErr := s.sanitizeURL(r.Url)
if sanitizeErr != nil {
return displayName, sanitizeErr
}
skipped, lm, err := s.downloadToPath(sanitizedURL, r.LocalPath, r.LastModified)
if err != nil { if err != nil {
if onStartup { if onStartup {
logger.Warningf("custom geo startup download id=%d: %v", id, err) logger.Warningf("custom geo startup download id=%d: %v", id, err)

View file

@ -12,6 +12,15 @@ import (
"github.com/mhsanaei/3x-ui/v2/database/model" "github.com/mhsanaei/3x-ui/v2/database/model"
) )
// disableSSRFCheck disables the SSRF guard for the duration of a test,
// allowing httptest servers on localhost. It restores the original on cleanup.
func disableSSRFCheck(t *testing.T) {
t.Helper()
orig := checkSSRF
checkSSRF = func(string) error { return nil }
t.Cleanup(func() { checkSSRF = orig })
}
func TestNormalizeAliasKey(t *testing.T) { func TestNormalizeAliasKey(t *testing.T) {
if got := NormalizeAliasKey("GeoIP-IR"); got != "geoip_ir" { if got := NormalizeAliasKey("GeoIP-IR"); got != "geoip_ir" {
t.Fatalf("got %q", got) t.Fatalf("got %q", got)
@ -139,14 +148,16 @@ func TestCustomGeoValidateAlias(t *testing.T) {
func TestCustomGeoValidateURL(t *testing.T) { func TestCustomGeoValidateURL(t *testing.T) {
s := CustomGeoService{} s := CustomGeoService{}
if err := s.validateURL(""); !errors.Is(err, ErrCustomGeoURLRequired) { if _, err := s.sanitizeURL(""); !errors.Is(err, ErrCustomGeoURLRequired) {
t.Fatal("empty") t.Fatal("empty")
} }
if err := s.validateURL("ftp://x"); !errors.Is(err, ErrCustomGeoURLScheme) { if _, err := s.sanitizeURL("ftp://x"); !errors.Is(err, ErrCustomGeoURLScheme) {
t.Fatal("ftp") t.Fatal("ftp")
} }
if err := s.validateURL("https://example.com/a.dat"); err != nil { if sanitized, err := s.sanitizeURL("https://example.com/a.dat"); err != nil {
t.Fatal(err) t.Fatal(err)
} else if sanitized != "https://example.com/a.dat" {
t.Fatalf("unexpected sanitized URL: %s", sanitized)
} }
} }
@ -161,6 +172,7 @@ func TestCustomGeoValidateType(t *testing.T) {
} }
func TestCustomGeoDownloadToPath(t *testing.T) { func TestCustomGeoDownloadToPath(t *testing.T) {
disableSSRFCheck(t)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "1") w.Header().Set("X-Test", "1")
if r.Header.Get("If-Modified-Since") != "" { if r.Header.Get("If-Modified-Since") != "" {
@ -193,6 +205,7 @@ func TestCustomGeoDownloadToPath(t *testing.T) {
} }
func TestCustomGeoDownloadToPath_missingLocalSendsNoIMSFromDB(t *testing.T) { func TestCustomGeoDownloadToPath_missingLocalSendsNoIMSFromDB(t *testing.T) {
disableSSRFCheck(t)
lm := "Wed, 21 Oct 2015 07:28:00 GMT" lm := "Wed, 21 Oct 2015 07:28:00 GMT"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("If-Modified-Since") != "" { if r.Header.Get("If-Modified-Since") != "" {
@ -221,6 +234,7 @@ func TestCustomGeoDownloadToPath_missingLocalSendsNoIMSFromDB(t *testing.T) {
} }
func TestCustomGeoDownloadToPath_repairSkipsConditional(t *testing.T) { func TestCustomGeoDownloadToPath_repairSkipsConditional(t *testing.T) {
disableSSRFCheck(t)
lm := "Wed, 21 Oct 2015 07:28:00 GMT" lm := "Wed, 21 Oct 2015 07:28:00 GMT"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("If-Modified-Since") != "" { if r.Header.Get("If-Modified-Since") != "" {
@ -297,6 +311,7 @@ func TestLocalDatFileNeedsRepair(t *testing.T) {
} }
func TestProbeCustomGeoURL_HEADOK(t *testing.T) { func TestProbeCustomGeoURL_HEADOK(t *testing.T) {
disableSSRFCheck(t)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodHead { if r.Method == http.MethodHead {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@ -311,6 +326,7 @@ func TestProbeCustomGeoURL_HEADOK(t *testing.T) {
} }
func TestProbeCustomGeoURL_HEAD405GETRange(t *testing.T) { func TestProbeCustomGeoURL_HEAD405GETRange(t *testing.T) {
disableSSRFCheck(t)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodHead { if r.Method == http.MethodHead {
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)

View file

@ -758,13 +758,13 @@ func extractHostname(host string) string {
func (s *SettingService) GetDefaultSettings(host string) (any, error) { func (s *SettingService) GetDefaultSettings(host string) (any, error) {
type settingFunc func() (any, error) type settingFunc func() (any, error)
settings := map[string]settingFunc{ settings := map[string]settingFunc{
"expireDiff": func() (any, error) { return s.GetExpireDiff() }, "expireDiff": func() (any, error) { return s.GetExpireDiff() },
"trafficDiff": func() (any, error) { return s.GetTrafficDiff() }, "trafficDiff": func() (any, error) { return s.GetTrafficDiff() },
"pageSize": func() (any, error) { return s.GetPageSize() }, "pageSize": func() (any, error) { return s.GetPageSize() },
"defaultCert": func() (any, error) { return s.GetCertFile() }, "defaultCert": func() (any, error) { return s.GetCertFile() },
"defaultKey": func() (any, error) { return s.GetKeyFile() }, "defaultKey": func() (any, error) { return s.GetKeyFile() },
"tgBotEnable": func() (any, error) { return s.GetTgbotEnabled() }, "tgBotEnable": func() (any, error) { return s.GetTgbotEnabled() },
"subEnable": func() (any, error) { return s.GetSubEnable() }, "subEnable": func() (any, error) { return s.GetSubEnable() },
"subJsonEnable": func() (any, error) { return s.GetSubJsonEnable() }, "subJsonEnable": func() (any, error) { return s.GetSubJsonEnable() },
"subClashEnable": func() (any, error) { return s.GetSubClashEnable() }, "subClashEnable": func() (any, error) { return s.GetSubClashEnable() },
"subTitle": func() (any, error) { return s.GetSubTitle() }, "subTitle": func() (any, error) { return s.GetSubTitle() },
@ -772,8 +772,8 @@ func (s *SettingService) GetDefaultSettings(host string) (any, error) {
"subJsonURI": func() (any, error) { return s.GetSubJsonURI() }, "subJsonURI": func() (any, error) { return s.GetSubJsonURI() },
"subClashURI": func() (any, error) { return s.GetSubClashURI() }, "subClashURI": func() (any, error) { return s.GetSubClashURI() },
"remarkModel": func() (any, error) { return s.GetRemarkModel() }, "remarkModel": func() (any, error) { return s.GetRemarkModel() },
"datepicker": func() (any, error) { return s.GetDatepicker() }, "datepicker": func() (any, error) { return s.GetDatepicker() },
"ipLimitEnable": func() (any, error) { return s.GetIpLimitEnable() }, "ipLimitEnable": func() (any, error) { return s.GetIpLimitEnable() },
} }
result := make(map[string]any) result := make(map[string]any)