diff --git a/.codex b/.codex new file mode 100644 index 00000000..e69de29b diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..f6646be9 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,2 @@ +每次修改go代码都要执行gofmt -w . +并进行run tests测试 \ No newline at end of file diff --git a/config/config.go b/config/config.go index 2c8a8b85..4f994292 100644 --- a/config/config.go +++ b/config/config.go @@ -213,6 +213,26 @@ func settingsLayoutMeta() map[string]any { } } +func ensureDefaultNodeSettings(settings map[string]any) { + group, ok := settings["other"].(map[string]any) + if !ok { + group = make(map[string]any) + settings["other"] = group + } + + defaults := map[string]string{ + "nodeRole": string(NodeRoleMaster), + "nodeId": "", + "syncInterval": "30", + "trafficFlushInterval": "10", + } + for key, value := range defaults { + if existing, exists := group[key]; !exists || existing == nil { + group[key] = value + } + } +} + func copyFile(src, dst string) error { in, err := os.Open(src) if err != nil { @@ -361,18 +381,26 @@ func ValidateNodeConfig(nodeCfg NodeConfig, dbCfg DBConfig) error { // It reads the existing file, updates the value, and writes back. func WriteSettingToJSON(key, value string) error { path := GetSettingPath() - data, err := os.ReadFile(path) - if err != nil { - return err - } - var settings map[string]any - if err := json.Unmarshal(data, &settings); err != nil { + data, err := os.ReadFile(path) + if err == nil { + if err := json.Unmarshal(data, &settings); err != nil { + return err + } + } else if os.IsNotExist(err) { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + settings = map[string]any{ + "_meta": settingsLayoutMeta(), + } + } else { return err } if _, exists := settings["_meta"]; !exists { settings["_meta"] = settingsLayoutMeta() } + ensureDefaultNodeSettings(settings) // Check if the key lives in a nested group if groups, ok := settingGroupAliases[key]; ok && len(groups) > 0 { diff --git a/config/config_test.go b/config/config_test.go index 3f0cb147..fcef9cf3 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -268,3 +268,100 @@ func TestWriteSettingToJSONUsesModulePurposeGroup(t *testing.T) { t.Fatalf("expected databaseConnection.dbHost to be updated, got %v", group["dbHost"]) } } + +func TestWriteSettingToJSONCreatesSettingsFileWhenMissing(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XUI_DB_FOLDER", tmpDir) + + if err := WriteSettingToJSON("dbType", "mariadb"); err != nil { + t.Fatalf("WriteSettingToJSON error: %v", err) + } + if err := WriteSettingToJSON("dbHost", "127.0.0.1"); err != nil { + t.Fatalf("WriteSettingToJSON error: %v", err) + } + + data, err := os.ReadFile(GetSettingPath()) + if err != nil { + t.Fatalf("ReadFile error: %v", err) + } + var parsed map[string]any + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + group, ok := parsed["databaseConnection"].(map[string]any) + if !ok { + t.Fatalf("expected databaseConnection group, got %T", parsed["databaseConnection"]) + } + if got, ok := group["dbType"].(string); !ok || got != "mariadb" { + t.Fatalf("expected databaseConnection.dbType to be updated, got %v", group["dbType"]) + } + if got, ok := group["dbHost"].(string); !ok || got != "127.0.0.1" { + t.Fatalf("expected databaseConnection.dbHost to be updated, got %v", group["dbHost"]) + } + other, ok := parsed["other"].(map[string]any) + if !ok { + t.Fatalf("expected other group, got %T", parsed["other"]) + } + if got, ok := other["nodeRole"].(string); !ok || got != "master" { + t.Fatalf("expected other.nodeRole to default to master, got %v", other["nodeRole"]) + } + if got, ok := other["nodeId"].(string); !ok || got != "" { + t.Fatalf("expected other.nodeId to default to empty string, got %v", other["nodeId"]) + } + if got, ok := other["syncInterval"].(string); !ok || got != "30" { + t.Fatalf("expected other.syncInterval to default to 30, got %v", other["syncInterval"]) + } + if got, ok := other["trafficFlushInterval"].(string); !ok || got != "10" { + t.Fatalf("expected other.trafficFlushInterval to default to 10, got %v", other["trafficFlushInterval"]) + } +} + +func TestWriteSettingToJSONBackfillsDefaultNodeSettings(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XUI_DB_FOLDER", tmpDir) + + initial := map[string]any{ + "_meta": map[string]any{ + "layout": "按模块-用途来归类", + }, + "databaseConnection": map[string]any{ + "dbType": "mariadb", + }, + } + data, err := json.MarshalIndent(initial, "", " ") + if err != nil { + t.Fatalf("MarshalIndent error: %v", err) + } + if err := os.WriteFile(GetSettingPath(), data, 0644); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + + if err := WriteSettingToJSON("dbHost", "127.0.0.1"); err != nil { + t.Fatalf("WriteSettingToJSON error: %v", err) + } + + updated, err := os.ReadFile(GetSettingPath()) + if err != nil { + t.Fatalf("ReadFile error: %v", err) + } + var parsed map[string]any + if err := json.Unmarshal(updated, &parsed); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + other, ok := parsed["other"].(map[string]any) + if !ok { + t.Fatalf("expected other group, got %T", parsed["other"]) + } + if got, ok := other["nodeRole"].(string); !ok || got != "master" { + t.Fatalf("expected other.nodeRole to default to master, got %v", other["nodeRole"]) + } + if got, ok := other["nodeId"].(string); !ok || got != "" { + t.Fatalf("expected other.nodeId to default to empty string, got %v", other["nodeId"]) + } + if got, ok := other["syncInterval"].(string); !ok || got != "30" { + t.Fatalf("expected other.syncInterval to default to 30, got %v", other["syncInterval"]) + } + if got, ok := other["trafficFlushInterval"].(string); !ok || got != "10" { + t.Fatalf("expected other.trafficFlushInterval to default to 10, got %v", other["trafficFlushInterval"]) + } +} diff --git a/database/db_test.go b/database/db_test.go index a64cb501..2534923f 100644 --- a/database/db_test.go +++ b/database/db_test.go @@ -2,11 +2,17 @@ package database import ( "bytes" + "log" "os" "path/filepath" + "strings" "testing" + "time" "github.com/mhsanaei/3x-ui/v2/database/model" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + glogger "gorm.io/gorm/logger" ) func setupTestDB(t *testing.T) { @@ -142,6 +148,11 @@ func TestInitDB_Idempotent(t *testing.T) { } CloseDB() + startSecond := time.Now().Unix() + for time.Now().Unix() == startSecond { + time.Sleep(10 * time.Millisecond) + } + // Second init on the same file should not fail if err := InitDBWithPath(dbPath); err != nil { t.Fatalf("second InitDB failed: %v", err) @@ -296,6 +307,90 @@ func TestBumpSharedAccountsVersion(t *testing.T) { } } +func TestSeedSharedAccountsVersion_UsesPrimaryKeyLookup(t *testing.T) { + var logs bytes.Buffer + dryLogger := glogger.New(log.New(&logs, "", 0), glogger.Config{ + LogLevel: glogger.Info, + Colorful: false, + }) + + dryDB, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{ + DryRun: true, + Logger: dryLogger, + }) + if err != nil { + t.Fatalf("open dry-run DB failed: %v", err) + } + + if err := seedSharedAccountsVersion(dryDB); err != nil { + t.Fatalf("seedSharedAccountsVersion error: %v", err) + } + + sqlLogs := logs.String() + if !strings.Contains(sqlLogs, "WHERE `shared_states`.`key` = \"shared_accounts_version\"") { + t.Fatalf("expected primary-key lookup SQL, got logs:\n%s", sqlLogs) + } + if strings.Contains(sqlLogs, "WHERE key = \"shared_accounts_version\"") { + t.Fatalf("expected seed query to avoid raw key lookup, got logs:\n%s", sqlLogs) + } +} + +func TestGetSharedAccountsVersion_UsesPrimaryKeyLookup(t *testing.T) { + var logs bytes.Buffer + dryLogger := glogger.New(log.New(&logs, "", 0), glogger.Config{ + LogLevel: glogger.Info, + Colorful: false, + }) + + dryDB, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{ + DryRun: true, + Logger: dryLogger, + }) + if err != nil { + t.Fatalf("open dry-run DB failed: %v", err) + } + + if _, err := GetSharedAccountsVersion(dryDB); err != nil { + t.Fatalf("GetSharedAccountsVersion error: %v", err) + } + + sqlLogs := logs.String() + if !strings.Contains(sqlLogs, "WHERE `shared_states`.`key` = \"shared_accounts_version\"") { + t.Fatalf("expected primary-key lookup SQL, got logs:\n%s", sqlLogs) + } + if strings.Contains(sqlLogs, "WHERE key = \"shared_accounts_version\"") { + t.Fatalf("expected version lookup to avoid raw key lookup, got logs:\n%s", sqlLogs) + } +} + +func TestBumpSharedAccountsVersion_UsesQuotedKeyColumn(t *testing.T) { + var logs bytes.Buffer + dryLogger := glogger.New(log.New(&logs, "", 0), glogger.Config{ + LogLevel: glogger.Info, + Colorful: false, + }) + + dryDB, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{ + DryRun: true, + Logger: dryLogger, + }) + if err != nil { + t.Fatalf("open dry-run DB failed: %v", err) + } + + if err := BumpSharedAccountsVersion(dryDB); err != nil { + t.Fatalf("BumpSharedAccountsVersion error: %v", err) + } + + sqlLogs := logs.String() + if !strings.Contains(sqlLogs, "WHERE `shared_states`.`key` = \"shared_accounts_version\"") { + t.Fatalf("expected quoted key column in update SQL, got logs:\n%s", sqlLogs) + } + if strings.Contains(sqlLogs, "WHERE key = \"shared_accounts_version\"") { + t.Fatalf("expected update SQL to avoid raw key lookup, got logs:\n%s", sqlLogs) + } +} + func TestUpsertNodeState(t *testing.T) { setupTestDB(t) diff --git a/database/shared_state.go b/database/shared_state.go index 1b4d5859..552128ce 100644 --- a/database/shared_state.go +++ b/database/shared_state.go @@ -17,19 +17,22 @@ func txOrDB(tx *gorm.DB) *gorm.DB { } func seedSharedAccountsVersion(tx *gorm.DB) error { - return txOrDB(tx).FirstOrCreate( - &model.SharedState{}, - &model.SharedState{ - Key: SharedAccountsVersionKey, + state := &model.SharedState{ + Key: SharedAccountsVersionKey, + } + return txOrDB(tx). + Attrs(&model.SharedState{ Version: 0, UpdatedAt: time.Now().Unix(), - }, - ).Error + }). + FirstOrCreate(state).Error } func GetSharedAccountsVersion(tx *gorm.DB) (int64, error) { - state := &model.SharedState{} - if err := txOrDB(tx).First(state, "key = ?", SharedAccountsVersionKey).Error; err != nil { + state := &model.SharedState{ + Key: SharedAccountsVersionKey, + } + if err := txOrDB(tx).First(state).Error; err != nil { return 0, err } return state.Version, nil @@ -37,7 +40,7 @@ func GetSharedAccountsVersion(tx *gorm.DB) (int64, error) { func BumpSharedAccountsVersion(tx *gorm.DB) error { return txOrDB(tx).Model(&model.SharedState{}). - Where("key = ?", SharedAccountsVersionKey). + Where(&model.SharedState{Key: SharedAccountsVersionKey}). Updates(map[string]any{ "version": gorm.Expr("version + 1"), "updated_at": time.Now().Unix(), diff --git a/install.sh b/install.sh index e92a3f12..508b9845 100644 --- a/install.sh +++ b/install.sh @@ -6,8 +6,6 @@ blue='\033[0;34m' yellow='\033[0;33m' plain='\033[0m' -cur_dir=$(pwd) - xui_folder="${XUI_MAIN_FOLDER:=/usr/local/x-ui}" xui_service="${XUI_SERVICE:=/etc/systemd/system}" @@ -49,9 +47,6 @@ is_ipv4() { is_ipv6() { [[ "$1" =~ : ]] && return 0 || return 1 } -is_ip() { - is_ipv4 "$1" || is_ipv6 "$1" -} is_domain() { [[ "$1" =~ ^([A-Za-z0-9](-*[A-Za-z0-9])*\.)+(xn--[a-z0-9]{2,}|[A-Za-z]{2,})$ ]] && return 0 || return 1 } @@ -60,11 +55,11 @@ is_domain() { is_port_in_use() { local port="$1" if command -v ss >/dev/null 2>&1; then - ss -ltn 2>/dev/null | awk -v p=":${port}$" '$4 ~ p {exit 0} END {exit 1}' + ss -ltn 2>/dev/null | awk -v p=":${port}$" '$4 ~ p {found=1} END {exit(found ? 0 : 1)}' return fi if command -v netstat >/dev/null 2>&1; then - netstat -lnt 2>/dev/null | awk -v p=":${port} " '$4 ~ p {exit 0} END {exit 1}' + netstat -lnt 2>/dev/null | awk -v p=":${port} " '$4 ~ p {found=1} END {exit(found ? 0 : 1)}' return fi if command -v lsof >/dev/null 2>&1; then @@ -103,6 +98,242 @@ install_base() { esac } +has_mariadb_cli() { + command -v mariadb >/dev/null 2>&1 || command -v mysql >/dev/null 2>&1 +} + +has_local_mariadb_service() { + if command -v systemctl >/dev/null 2>&1; then + systemctl list-unit-files 2>/dev/null | grep -qE '^(mariadb|mysql)\.service$' && return 0 + fi + [[ -f /etc/init.d/mariadb ]] +} + +mariadb_cli_bin() { + if command -v mariadb >/dev/null 2>&1; then + command -v mariadb + return 0 + fi + if command -v mysql >/dev/null 2>&1; then + command -v mysql + return 0 + fi + return 1 +} + +install_mariadb_client() { + echo -e "${green}正在安装 MariaDB 客户端...${plain}" + case "${release}" in + ubuntu | debian | armbian | linuxmint) + apt-get update -y && apt-get install -y mariadb-client + ;; + fedora) + dnf install -y mariadb + ;; + centos | rhel | almalinux | rocky | ol | alinux | amzn) + if command -v dnf >/dev/null 2>&1; then + dnf install -y mariadb + else + yum install -y mariadb + fi + ;; + arch | manjaro | parch) + pacman -Sy --noconfirm mariadb-clients >/dev/null 2>&1 || pacman -Sy --noconfirm mariadb + ;; + opensuse-tumbleweed | opensuse-leap | sles) + zypper install -y mariadb-client + ;; + alpine) + apk add mariadb-client + ;; + *) + apt-get update -y && apt-get install -y mariadb-client + ;; + esac +} + +install_local_mariadb_server() { + echo -e "${green}正在安装本地 MariaDB...${plain}" + case "${release}" in + ubuntu | debian | armbian | linuxmint) + apt-get update -y && apt-get install -y mariadb-server mariadb-client + ;; + fedora) + dnf install -y mariadb-server mariadb + ;; + centos | rhel | almalinux | rocky | ol | alinux | amzn) + if command -v dnf >/dev/null 2>&1; then + dnf install -y mariadb-server mariadb + else + yum install -y mariadb-server mariadb + fi + ;; + arch | manjaro | parch) + pacman -Sy --noconfirm mariadb + mariadb-install-db --user=mysql --basedir=/usr --datadir=/var/lib/mysql >/dev/null 2>&1 || true + ;; + opensuse-tumbleweed | opensuse-leap | sles) + zypper install -y mariadb-server mariadb-client + ;; + alpine) + apk add mariadb mariadb-client + mariadb-install-db --user=mysql --basedir=/usr --datadir=/var/lib/mysql >/dev/null 2>&1 || true + ;; + *) + echo -e "${red}不支持的发行版: ${release},请手动安装 MariaDB${plain}" + return 1 + ;; + esac +} + +start_mariadb_service() { + local svc_name="" + if command -v systemctl >/dev/null 2>&1; then + if systemctl list-unit-files 2>/dev/null | grep -q "^mariadb.service"; then + svc_name="mariadb" + elif systemctl list-unit-files 2>/dev/null | grep -q "^mysql.service"; then + svc_name="mysql" + fi + fi + + if [ -n "$svc_name" ]; then + systemctl start "$svc_name" 2>/dev/null + systemctl enable "$svc_name" 2>/dev/null + return 0 + fi + + if [[ $release == "alpine" ]]; then + rc-service mariadb start 2>/dev/null + rc-update add mariadb 2>/dev/null + return $? + fi + + return 1 +} + +ensure_mariadb_client_ready() { + if has_mariadb_cli; then + return 0 + fi + install_mariadb_client || return 1 + has_mariadb_cli +} + +ensure_local_mariadb_ready() { + if ! has_local_mariadb_service; then + install_local_mariadb_server || return 1 + fi + ensure_mariadb_client_ready || return 1 + start_mariadb_service || true + return 0 +} + +test_mariadb_server_connection() { + local host="$1" port="$2" user="$3" pass="$4" + local bin + bin=$(mariadb_cli_bin) || return 1 + "$bin" -h "$host" -P "$port" -u "$user" -p"$pass" -e "SELECT 1;" >/dev/null 2>&1 +} + +test_mariadb_database_connection() { + local host="$1" port="$2" dbname="$3" user="$4" pass="$5" + local bin + bin=$(mariadb_cli_bin) || return 1 + "$bin" -h "$host" -P "$port" -u "$user" -p"$pass" -D "$dbname" -e "SELECT 1;" >/dev/null 2>&1 +} + +is_safe_mariadb_identifier() { + [[ "$1" =~ ^[A-Za-z0-9_.-]+$ ]] +} + +escape_sql_string() { + printf "%s" "$1" | sed "s/'/''/g" +} + +LOCAL_MARIADB_ADMIN_MODE="" +LOCAL_MARIADB_ADMIN_USER="" +LOCAL_MARIADB_ADMIN_PASS="" +LOCAL_MARIADB_ADMIN_PORT="3306" + +try_local_mariadb_socket_admin() { + local bin + bin=$(mariadb_cli_bin) || return 1 + "$bin" -e "SELECT 1;" >/dev/null 2>&1 || "$bin" -uroot -e "SELECT 1;" >/dev/null 2>&1 +} + +ensure_local_mariadb_admin_access() { + local port="${1:-3306}" + LOCAL_MARIADB_ADMIN_PORT="$port" + + if try_local_mariadb_socket_admin; then + LOCAL_MARIADB_ADMIN_MODE="socket" + return 0 + fi + + local admin_user admin_pass + echo -e "${yellow}无法通过 root socket 直接连接本地 MariaDB,请输入管理员账号信息。${plain}" + read -rp "MariaDB 管理员用户名 [root]: " admin_user + admin_user="${admin_user:-root}" + read -rsp "MariaDB 管理员密码: " admin_pass + echo + + if ! test_mariadb_server_connection "127.0.0.1" "$port" "$admin_user" "$admin_pass"; then + echo -e "${red}管理员账号连接失败${plain}" + return 1 + fi + + LOCAL_MARIADB_ADMIN_MODE="password" + LOCAL_MARIADB_ADMIN_USER="$admin_user" + LOCAL_MARIADB_ADMIN_PASS="$admin_pass" +} + +run_local_mariadb_admin_sql() { + local sql="$1" + local bin + bin=$(mariadb_cli_bin) || return 1 + + case "$LOCAL_MARIADB_ADMIN_MODE" in + socket) + "$bin" -e "$sql" >/dev/null 2>&1 || "$bin" -uroot -e "$sql" >/dev/null 2>&1 + ;; + password) + "$bin" -h "127.0.0.1" -P "$LOCAL_MARIADB_ADMIN_PORT" -u "$LOCAL_MARIADB_ADMIN_USER" -p"$LOCAL_MARIADB_ADMIN_PASS" -e "$sql" >/dev/null 2>&1 + ;; + *) + return 1 + ;; + esac +} + +ensure_mariadb_database_and_user() { + local dbname="$1" dbuser="$2" dbpass="$3" + local escaped_pass + local sql="" + local account_host="" + + if ! is_safe_mariadb_identifier "$dbname"; then + echo -e "${red}业务数据库名仅支持字母、数字、点、下划线和连字符${plain}" + return 1 + fi + if ! is_safe_mariadb_identifier "$dbuser"; then + echo -e "${red}业务用户名仅支持字母、数字、点、下划线和连字符${plain}" + return 1 + fi + + escaped_pass=$(escape_sql_string "$dbpass") + sql="CREATE DATABASE IF NOT EXISTS \`${dbname}\` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;" + + for account_host in "localhost" "127.0.0.1" "::1"; do + sql="${sql} CREATE USER IF NOT EXISTS '${dbuser}'@'${account_host}' IDENTIFIED BY '${escaped_pass}';" + sql="${sql} ALTER USER '${dbuser}'@'${account_host}' IDENTIFIED BY '${escaped_pass}';" + sql="${sql} GRANT ALL PRIVILEGES ON \`${dbname}\`.* TO '${dbuser}'@'${account_host}';" + done + sql="${sql} FLUSH PRIVILEGES;" + + echo -e "${green}正在确保本地 MariaDB 的业务库和业务账号存在...${plain}" + run_local_mariadb_admin_sql "$sql" +} + gen_random_string() { local length="$1" openssl rand -base64 $(( length * 2 )) \ @@ -110,6 +341,22 @@ gen_random_string() { | head -c "$length" } +is_safe_install_path() { + local target="$1" + local resolved_target + + [[ -n "$target" ]] || return 1 + resolved_target=$(readlink -f "$target" 2>/dev/null || echo "$target") + + case "$resolved_target" in + "/" | "/usr" | "/usr/" | "/usr/local" | "/usr/local/" | "/etc" | "/etc/") + return 1 + ;; + esac + + return 0 +} + save_panel_domain() { local domain="$1" if [[ -z "$domain" ]]; then @@ -153,10 +400,16 @@ verify_panel_cert_paths() { } install_acme() { + local previous_dir + local install_status + echo -e "${green}正在安装 acme.sh 用于 SSL 证书管理...${plain}" + previous_dir=$(pwd) cd ~ || return 1 curl -s https://get.acme.sh | sh >/dev/null 2>&1 - if [ $? -ne 0 ]; then + install_status=$? + cd "$previous_dir" >/dev/null 2>&1 || true + if [ $install_status -ne 0 ]; then echo -e "${red}安装 acme.sh 失败${plain}" return 1 else @@ -165,76 +418,6 @@ install_acme() { return 0 } -setup_ssl_certificate() { - local domain="$1" - local server_ip="$2" - local existing_port="$3" - local existing_webBasePath="$4" - - echo -e "${green}正在配置 SSL 证书...${plain}" - - # 检查 acme.sh 是否已安装 - if ! command -v ~/.acme.sh/acme.sh &>/dev/null; then - install_acme - if [ $? -ne 0 ]; then - echo -e "${yellow}安装 acme.sh 失败,跳过 SSL 配置${plain}" - return 1 - fi - fi - - # 创建证书目录 - local certPath="/root/cert/${domain}" - mkdir -p "$certPath" - - # 签发证书 - echo -e "${green}正在为 ${domain} 签发 SSL 证书...${plain}" - echo -e "${yellow}注意:80 端口必须开放且可从外网访问${plain}" - - ~/.acme.sh/acme.sh --set-default-ca --server letsencrypt --force >/dev/null 2>&1 - ~/.acme.sh/acme.sh --issue -d ${domain} --listen-v6 --standalone --httpport 80 --force - - if [ $? -ne 0 ]; then - echo -e "${yellow}为 ${domain} 签发证书失败${plain}" - echo -e "${yellow}请确保 80 端口已开放,稍后可通过以下命令重试:x-ui${plain}" - rm -rf ~/.acme.sh/${domain} 2>/dev/null - rm -rf "$certPath" 2>/dev/null - return 1 - fi - - # 安装证书 - ~/.acme.sh/acme.sh --installcert -d ${domain} \ - --key-file /root/cert/${domain}/privkey.pem \ - --fullchain-file /root/cert/${domain}/fullchain.pem \ - --reloadcmd "systemctl restart x-ui" >/dev/null 2>&1 - - if [ $? -ne 0 ]; then - echo -e "${yellow}安装证书失败${plain}" - return 1 - fi - - # 启用自动续期 - ~/.acme.sh/acme.sh --upgrade --auto-upgrade >/dev/null 2>&1 - # 安全权限:私钥仅所有者可读 - chmod 600 $certPath/privkey.pem 2>/dev/null - chmod 644 $certPath/fullchain.pem 2>/dev/null - - # 为面板设置证书 - local webCertFile="/root/cert/${domain}/fullchain.pem" - local webKeyFile="/root/cert/${domain}/privkey.pem" - - if [[ -f "$webCertFile" && -f "$webKeyFile" ]]; then - ${xui_folder}/x-ui cert -webCert "$webCertFile" -webCertKey "$webKeyFile" >/dev/null 2>&1 - if ! verify_panel_cert_paths "$webCertFile" "$webKeyFile"; then - return 1 - fi - echo -e "${green}SSL 证书安装并配置成功!${plain}" - return 0 - else - echo -e "${yellow}未找到证书文件${plain}" - return 1 - fi -} - # 签发 Let's Encrypt IP 证书(短期配置文件,约 6 天有效期) # 需要 acme.sh 且 80 端口开放用于 HTTP-01 验证 setup_ip_certificate() { @@ -370,16 +553,17 @@ setup_ip_certificate() { # 为面板配置证书路径 echo -e "${green}正在为面板设置证书路径...${plain}" - ${xui_folder}/x-ui cert -webCert "${certDir}/fullchain.pem" -webCertKey "${certDir}/privkey.pem" - - if [ $? -ne 0 ]; then - echo -e "${yellow}警告:无法自动设置证书路径${plain}" + if ! "${xui_folder}/x-ui" cert -webCert "${certDir}/fullchain.pem" -webCertKey "${certDir}/privkey.pem" >/dev/null 2>&1; then + echo -e "${red}无法自动设置证书路径${plain}" echo -e "${yellow}证书文件位于:${plain}" echo -e " 证书:${certDir}/fullchain.pem" echo -e " 密钥:${certDir}/privkey.pem" - else - echo -e "${green}证书路径配置成功${plain}" + return 1 fi + if ! verify_panel_cert_paths "${certDir}/fullchain.pem" "${certDir}/privkey.pem"; then + return 1 + fi + echo -e "${green}证书路径配置成功${plain}" echo -e "${green}IP 证书安装并配置成功!${plain}" echo -e "${green}证书有效期约 6 天,通过 acme.sh cron 任务自动续期。${plain}" @@ -395,13 +579,9 @@ ssl_cert_issue() { # 检查 acme.sh if ! command -v ~/.acme.sh/acme.sh &>/dev/null; then echo "未找到 acme.sh,正在安装..." - cd ~ || return 1 - curl -s https://get.acme.sh | sh - if [ $? -ne 0 ]; then + if ! install_acme; then echo -e "${red}安装 acme.sh 失败${plain}" return 1 - else - echo -e "${green}acme.sh 安装成功${plain}" fi fi @@ -426,8 +606,7 @@ ssl_cert_issue() { echo -e "${green}您的域名是:${domain},正在检查...${plain}" # 检查是否已存在证书 - local currentCert=$(~/.acme.sh/acme.sh --list | tail -1 | awk '{print $1}') - if [ "${currentCert}" == "${domain}" ]; then + if ~/.acme.sh/acme.sh --list 2>/dev/null | awk 'NR>1 {print $1}' | grep -Fxq "${domain}"; then local certInfo=$(~/.acme.sh/acme.sh --list) echo -e "${red}系统已有该域名的证书,无法重复签发。${plain}" echo -e "${yellow}当前证书信息:${plain}" @@ -449,7 +628,9 @@ ssl_cert_issue() { # 获取独立服务器端口号 local WebPort=80 read -rp "请选择要使用的端口(默认 80):" WebPort - if [[ ${WebPort} -gt 65535 || ${WebPort} -lt 1 ]]; then + WebPort="${WebPort// /}" + WebPort="${WebPort:-80}" + if ! [[ "${WebPort}" =~ ^[0-9]+$ ]] || ((WebPort < 1 || WebPort > 65535)); then echo -e "${yellow}输入 ${WebPort} 无效,将使用默认端口 80。${plain}" WebPort=80 fi @@ -547,11 +728,14 @@ ssl_cert_issue() { echo -e "${green}访问地址:https://${domain}:${existing_port}/${existing_webBasePath}${plain}" echo -e "${yellow}面板将重启以应用 SSL 证书...${plain}" systemctl restart x-ui 2>/dev/null || rc-service x-ui restart 2>/dev/null + SSL_HOST="${domain}" else echo -e "${red}错误:未找到域名 ${domain} 的证书或私钥文件。${plain}" + return 1 fi else - echo -e "${yellow}跳过面板路径设置。${plain}" + echo -e "${yellow}未将证书应用到面板,SSL 配置未完成。${plain}" + return 1 fi return 0 @@ -565,6 +749,7 @@ prompt_and_setup_ssl() { local server_ip="$3" local ssl_choice="" + SSL_HOST="" echo -e "${yellow}选择 SSL 证书配置方式:${plain}" echo -e "${green}1.${plain} Let's Encrypt 域名证书(90 天有效期,自动续期)" @@ -575,7 +760,7 @@ prompt_and_setup_ssl() { read -rp "请选择(默认 2 使用 IP):" ssl_choice ssl_choice="${ssl_choice// /}" # 去除空格 - # 如果输入为空或无效(非 1、3 或 4),默认为 2(IP 证书) + # 除 1/3/4 外,其余输入均视为 2(IP 证书) if [[ "$ssl_choice" != "1" && "$ssl_choice" != "3" && "$ssl_choice" != "4" ]]; then ssl_choice="2" fi @@ -584,25 +769,36 @@ prompt_and_setup_ssl() { 1) # 用户选择 Let's Encrypt 域名选项 echo -e "${green}使用 Let's Encrypt 域名证书...${plain}" - ssl_cert_issue - # 从证书中提取使用的域名 - local cert_domain=$(~/.acme.sh/acme.sh --list 2>/dev/null | tail -1 | awk '{print $1}') - if [[ -n "${cert_domain}" ]]; then - if ! save_panel_domain "${cert_domain}"; then - SSL_HOST="${server_ip}" - return 1 - fi - SSL_HOST="${cert_domain}" - echo -e "${green}✓ SSL 证书配置成功,域名:${cert_domain}${plain}" + if ssl_cert_issue; then + echo -e "${green}✓ SSL 证书配置成功,域名:${SSL_HOST}${plain}" else - echo -e "${yellow}SSL 配置可能已完成,但域名提取失败${plain}" + echo -e "${red}✗ 域名证书配置失败。${plain}" SSL_HOST="${server_ip}" + return 1 fi ;; 2) # 用户选择 Let's Encrypt IP 证书选项 echo -e "${green}使用 Let's Encrypt IP 证书(短期配置文件)...${plain}" + if [[ -z "${server_ip}" ]]; then + local manual_ipv4="" + echo -e "${yellow}未能自动检测到服务器公网 IPv4。${plain}" + while true; do + read -rp "请输入服务器公网 IPv4(留空取消):" manual_ipv4 + manual_ipv4="${manual_ipv4// /}" + if [[ -z "${manual_ipv4}" ]]; then + echo -e "${red}未提供公网 IPv4,无法继续 IP 证书配置。${plain}" + return 1 + fi + if is_ipv4 "${manual_ipv4}"; then + server_ip="${manual_ipv4}" + break + fi + echo -e "${red}无效的 IPv4 地址:${manual_ipv4}${plain}" + done + fi + # 询问可选的 IPv6 local ipv6_addr="" read -rp "是否包含 IPv6 地址?(留空跳过):" ipv6_addr @@ -622,6 +818,7 @@ prompt_and_setup_ssl() { else echo -e "${red}✗ IP 证书配置失败。请检查 80 端口是否已开放。${plain}" SSL_HOST="${server_ip}" + return 1 fi ;; 3) @@ -711,8 +908,14 @@ prompt_and_setup_ssl() { SSL_HOST="${server_ip}" return 1 fi + if ! is_domain "$cf_domain"; then + echo -e "${red}无效的域名格式:${cf_domain}${plain}" + SSL_HOST="${server_ip}" + return 1 + fi - read -rp "请输入 Cloudflare 全局 API 密钥:" cf_key + read -rsp "请输入 Cloudflare 全局 API 密钥:" cf_key + echo cf_key="${cf_key// /}" if [[ -z "$cf_key" ]]; then echo -e "${red}API 密钥不能为空,跳过 SSL 配置。${plain}" @@ -755,6 +958,7 @@ prompt_and_setup_ssl() { echo -e "${yellow}正在通过 Cloudflare DNS 签发证书...${plain}" ~/.acme.sh/acme.sh --issue --dns dns_cf -d "${cf_domain}" -d "*.${cf_domain}" --log --force if [ $? -ne 0 ]; then + unset CF_Key CF_Email echo -e "${red}证书签发失败,请检查 Cloudflare API 密钥和域名是否正确。${plain}" SSL_HOST="${server_ip}" return 1 @@ -765,6 +969,7 @@ prompt_and_setup_ssl() { rm -rf "${certPath}" mkdir -p "${certPath}" if [ $? -ne 0 ]; then + unset CF_Key CF_Email echo -e "${red}创建目录失败:${certPath}${plain}" SSL_HOST="${server_ip}" return 1 @@ -776,6 +981,7 @@ prompt_and_setup_ssl() { --fullchain-file "${certPath}/fullchain.pem" --reloadcmd "${reloadCmd}" if [ $? -ne 0 ]; then + unset CF_Key CF_Email echo -e "${red}证书安装失败。${plain}" SSL_HOST="${server_ip}" return 1 @@ -796,20 +1002,24 @@ prompt_and_setup_ssl() { if [[ -f "$webCertFile" && -f "$webKeyFile" ]]; then ${xui_folder}/x-ui cert -webCert "$webCertFile" -webCertKey "$webKeyFile" >/dev/null 2>&1 if ! verify_panel_cert_paths "$webCertFile" "$webKeyFile"; then + unset CF_Key CF_Email SSL_HOST="${server_ip}" return 1 fi if ! save_panel_domain "$cf_domain"; then + unset CF_Key CF_Email SSL_HOST="${server_ip}" return 1 fi echo -e "${green}✓ 面板证书已设置。${plain}" else + unset CF_Key CF_Email echo -e "${red}未找到证书或私钥文件。${plain}" SSL_HOST="${server_ip}" return 1 fi + unset CF_Key CF_Email SSL_HOST="${cf_domain}" echo -e "${green}✓ Cloudflare SSL 证书配置完成。${plain}" echo -e "${yellow}注意:证书支持自动续期,无需手动管理。${plain}" @@ -819,8 +1029,11 @@ prompt_and_setup_ssl() { *) echo -e "${red}无效选项。跳过 SSL 配置。${plain}" SSL_HOST="${server_ip}" + return 1 ;; esac + + return 0 } config_after_install() { @@ -864,10 +1077,11 @@ config_after_install() { config_username="admin" fi - read -rp "请输入密码 [默认 admin]:" config_password + read -rp "请输入密码 [默认随机生成]:" config_password config_password="${config_password// /}" if [[ -z "$config_password" || "$config_password" == "rd" ]]; then - config_password="admin" + config_password=$(gen_random_string 18) + echo -e "${green}已生成随机密码:${config_password}${plain}" fi read -rp "请输入 Web 路径(不含前导 /):" config_webBasePath @@ -880,22 +1094,21 @@ config_after_install() { read -rp "是否要自定义面板端口?(否则将使用随机端口)[y/n]:" config_confirm if [[ "${config_confirm}" == "y" || "${config_confirm}" == "Y" ]]; then - read -rp "请设置面板端口:" config_port + while true; do + read -rp "请设置面板端口:" config_port + config_port="${config_port// /}" + if ! [[ "${config_port}" =~ ^[0-9]+$ ]] || ((config_port < 1 || config_port > 65535)); then + echo -e "${red}无效端口,请输入 1-65535 之间的数字。${plain}" + continue + fi + break + done echo -e "${yellow}您的面板端口为:${config_port}${plain}" else local config_port=$(shuf -i 1024-62000 -n 1) echo -e "${yellow}已生成随机端口:${config_port}${plain}" fi - ${xui_folder}/x-ui setting -username "${config_username}" -password "${config_password}" -port "${config_port}" -webBasePath "${config_webBasePath}" - local saved_port - saved_port=$(${xui_folder}/x-ui setting -show true 2>/dev/null | grep '^port:' | awk -F': ' '{print $2}' | tr -d '[:space:]') - if [[ "${saved_port}" != "${config_port}" ]]; then - echo -e "${red}端口未写入配置文件:期望 ${config_port},实际 ${saved_port:-空}${plain}" - return 1 - fi - config_port="${saved_port}" - read -rp "Database type [mariadb]: " db_type db_type=$(echo "${db_type:-mariadb}" | tr '[:upper:]' '[:lower:]' | tr -d '[:space:]') if [[ "${db_type}" != "mariadb" && "${db_type}" != "sqlite" ]]; then @@ -904,20 +1117,88 @@ config_after_install() { fi if [[ "${db_type}" == "mariadb" ]]; then - read -rp "MariaDB host [127.0.0.1]: " db_host - read -rp "MariaDB port [3306]: " db_port - read -rp "MariaDB user: " db_user - read -rsp "MariaDB password: " db_pass - echo - read -rp "MariaDB database [3xui]: " db_name + local mariadb_mode_choice mariadb_mode db_host db_port db_user db_pass db_name + read -rp "MariaDB 部署位置 [1=本地 MariaDB, 2=远程 MariaDB,默认 1]: " mariadb_mode_choice + case "${mariadb_mode_choice:-1}" in + 2) + mariadb_mode="remote" + ;; + *) + mariadb_mode="local" + ;; + esac - XUI_DB_PASSWORD="$db_pass" ${xui_folder}/x-ui setting \ - -dbHost "${db_host:-127.0.0.1}" \ - -dbPort "${db_port:-3306}" \ + if [[ "${mariadb_mode}" == "remote" ]]; then + read -rp "远程 MariaDB host [127.0.0.1]: " db_host + read -rp "远程 MariaDB port [3306]: " db_port + read -rp "业务数据库名 [3xui]: " db_name + read -rp "业务用户名: " db_user + read -rsp "业务密码: " db_pass + echo + + db_host="${db_host:-127.0.0.1}" + db_port="${db_port:-3306}" + db_name="${db_name:-3xui}" + + if [[ -z "$db_user" || -z "$db_pass" ]]; then + echo -e "${red}远程 MariaDB 的业务用户名和业务密码不能为空${plain}" + return 1 + fi + + ensure_mariadb_client_ready || { + echo -e "${red}安装 MariaDB 客户端失败${plain}" + return 1 + } + + echo -e "${green}正在验证远程 MariaDB 业务连接...${plain}" + if ! test_mariadb_database_connection "$db_host" "$db_port" "$db_name" "$db_user" "$db_pass"; then + echo -e "${red}无法使用输入的远程 MariaDB 信息连接到业务数据库${plain}" + return 1 + fi + else + db_host="127.0.0.1" + db_port="3306" + read -rp "业务数据库名 [3xui]: " db_name + read -rp "业务用户名: " db_user + read -rsp "业务密码: " db_pass + echo + + db_name="${db_name:-3xui}" + if [[ -z "$db_user" || -z "$db_pass" ]]; then + echo -e "${red}本地 MariaDB 的业务用户名和业务密码不能为空${plain}" + return 1 + fi + + ensure_local_mariadb_ready || { + echo -e "${red}准备本地 MariaDB 失败${plain}" + return 1 + } + ensure_local_mariadb_admin_access "$db_port" || return 1 + ensure_mariadb_database_and_user "$db_name" "$db_user" "$db_pass" || { + echo -e "${red}创建本地 MariaDB 业务库或业务账号失败${plain}" + return 1 + } + + echo -e "${green}正在验证本地 MariaDB 业务连接...${plain}" + if ! test_mariadb_database_connection "$db_host" "$db_port" "$db_name" "$db_user" "$db_pass"; then + echo -e "${red}无法使用创建后的本地 MariaDB 业务账号连接数据库${plain}" + return 1 + fi + fi + + if ! XUI_DB_PASSWORD="$db_pass" ${xui_folder}/x-ui setting \ + -dbType "${db_type}" \ + -dbHost "${db_host}" \ + -dbPort "${db_port}" \ -dbUser "$db_user" \ - -dbName "${db_name:-3xui}" + -dbName "${db_name}" >/dev/null 2>&1; then + echo -e "${red}写入 MariaDB 配置失败${plain}" + return 1 + fi + elif ! ${xui_folder}/x-ui setting -dbType "${db_type}" >/dev/null 2>&1; then + echo -e "${red}写入数据库类型失败${plain}" + return 1 fi - ${xui_folder}/x-ui setting -dbType "${db_type}" read -rp "Node role [master]: " node_role node_role=$(echo "${node_role:-master}" | tr '[:upper:]' '[:lower:]' | tr -d '[:space:]') @@ -926,6 +1207,22 @@ config_after_install() { node_role="master" fi + read -rp "Sync interval [30]: " sync_interval + sync_interval="${sync_interval:-30}" + while ! [[ "${sync_interval}" =~ ^[1-9][0-9]*$ ]]; do + echo -e "${yellow}同步间隔必须为正整数${plain}" + read -rp "Sync interval [30]: " sync_interval + sync_interval="${sync_interval:-30}" + done + + read -rp "Traffic flush interval [10]: " traffic_flush_interval + traffic_flush_interval="${traffic_flush_interval:-10}" + while ! [[ "${traffic_flush_interval}" =~ ^[1-9][0-9]*$ ]]; do + echo -e "${yellow}流量回刷间隔必须为正整数${plain}" + read -rp "Traffic flush interval [10]: " traffic_flush_interval + traffic_flush_interval="${traffic_flush_interval:-10}" + done + if [[ "${node_role}" == "worker" && "${db_type}" != "mariadb" ]]; then echo -e "${yellow}worker 节点要求使用 MariaDB,回退到 master${plain}" node_role="master" @@ -939,11 +1236,29 @@ config_after_install() { read -rp "Node ID: " node_id node_id="${node_id// /}" done - ${xui_folder}/x-ui setting -nodeRole worker -nodeId "$node_id" + if ! ${xui_folder}/x-ui setting -nodeRole worker -nodeId "$node_id" -syncInterval "${sync_interval}" -trafficFlushInterval "${traffic_flush_interval}" >/dev/null 2>&1; then + echo -e "${red}写入 worker 节点配置失败${plain}" + return 1 + fi else - ${xui_folder}/x-ui setting -nodeRole master + if ! ${xui_folder}/x-ui setting -nodeRole master -nodeId "" -syncInterval "${sync_interval}" -trafficFlushInterval "${traffic_flush_interval}" >/dev/null 2>&1; then + echo -e "${red}写入 master 节点配置失败${plain}" + return 1 + fi fi + if ! ${xui_folder}/x-ui setting -username "${config_username}" -password "${config_password}" -port "${config_port}" -webBasePath "${config_webBasePath}" >/dev/null 2>&1; then + echo -e "${red}写入面板基础配置失败,请检查数据库配置${plain}" + return 1 + fi + local saved_port + saved_port=$(${xui_folder}/x-ui setting -show true 2>/dev/null | grep '^port:' | awk -F': ' '{print $2}' | tr -d '[:space:]') + if [[ "${saved_port}" != "${config_port}" ]]; then + echo -e "${red}端口未写入配置文件:期望 ${config_port},实际 ${saved_port:-空}${plain}" + return 1 + fi + config_port="${saved_port}" + echo "" echo -e "${green}═══════════════════════════════════════════${plain}" echo -e "${green} SSL 证书配置(必需) ${plain}" @@ -952,7 +1267,12 @@ config_after_install() { echo -e "${yellow}Let's Encrypt 现已支持域名和 IP 地址!${plain}" echo "" - prompt_and_setup_ssl "${config_port}" "${config_webBasePath}" "${server_ip}" + if ! prompt_and_setup_ssl "${config_port}" "${config_webBasePath}" "${server_ip}"; then + echo -e "${red}SSL 配置失败,安装终止。${plain}" + return 1 + fi + local access_scheme="https" + local access_host="${SSL_HOST:-${server_ip:-localhost}}" # 显示最终凭据和访问信息 echo "" @@ -963,7 +1283,7 @@ config_after_install() { echo -e "${green}密码: ${config_password}${plain}" echo -e "${green}端口: ${config_port}${plain}" echo -e "${green}Web路径: ${config_webBasePath}${plain}" - echo -e "${green}访问地址: https://${SSL_HOST}:${config_port}/${config_webBasePath}${plain}" + echo -e "${green}访问地址: ${access_scheme}://${access_host}:${config_port}/${config_webBasePath}${plain}" echo -e "${green}═══════════════════════════════════════════${plain}" echo -e "${yellow}⚠ 重要:请安全保存这些凭据!${plain}" echo -e "${yellow}⚠ SSL 证书:已启用并配置${plain}" @@ -1075,13 +1395,17 @@ install_x-ui() { fi # 停止 x-ui 服务并删除旧资源 - if [[ -e ${xui_folder}/ ]]; then + if [[ -e "${xui_folder}/" ]]; then if [[ $release == "alpine" ]]; then rc-service x-ui stop else systemctl stop x-ui fi - rm ${xui_folder}/ -rf + if ! is_safe_install_path "${xui_folder}"; then + echo -e "${red}拒绝删除危险安装目录:${xui_folder}${plain}" + exit 1 + fi + rm -rf "${xui_folder}/" fi # 解压资源并设置权限 @@ -1103,7 +1427,7 @@ install_x-ui() { mv -f /usr/bin/x-ui-temp /usr/bin/x-ui chmod +x /usr/bin/x-ui mkdir -p /var/log/x-ui - config_after_install + config_after_install || exit 1 # Etckeeper 兼容性 if [ -d "/etc/.git" ]; then diff --git a/main.go b/main.go index 4e030470..b18908f9 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,51 @@ import ( "github.com/op/go-logging" ) +type settingCommandOptions struct { + port int + username string + password string + webBasePath string + webDomain string + listenIP string + reset bool + show bool + getListen bool + getCert bool + resetTwoFactor bool + tgbotToken string + tgbotChatID string + tgbotRuntime string + enableTgbot bool + dbType string + dbHost string + dbPort string + dbUser string + dbPassword string + dbName string + nodeRoleSet bool + nodeIDSet bool + syncIntervalSet bool + trafficFlushIntervalSet bool +} + +func (o settingCommandOptions) needsDBInit() bool { + return o.port > 0 || + o.username != "" || + o.password != "" || + o.webBasePath != "" || + o.webDomain != "" || + o.listenIP != "" || + o.show || + o.getListen || + o.getCert || + o.resetTwoFactor || + o.tgbotToken != "" || + o.tgbotChatID != "" || + o.tgbotRuntime != "" || + o.enableTgbot +} + // runWebServer initializes and starts the web server for the 3x-ui panel. func runWebServer() { log.Printf("Starting %v %v", config.GetName(), config.GetVersion()) @@ -235,12 +280,6 @@ func updateTgbotEnableSts(status bool) { // updateTgbotSetting updates Telegram bot settings including token, chat ID, and runtime schedule. func updateTgbotSetting(tgBotToken string, tgBotChatid string, tgBotRuntime string) { - err := database.InitDB() - if err != nil { - fmt.Println("Error initializing database:", err) - return - } - settingService := service.SettingService{} if tgBotToken != "" { @@ -273,12 +312,6 @@ func updateTgbotSetting(tgBotToken string, tgBotChatid string, tgBotRuntime stri // updateSetting updates various panel settings including port, domain, credentials, base path, listen IP, and two-factor authentication. func updateSetting(port int, username string, password string, webBasePath string, webDomain string, listenIP string, resetTwoFactor bool) { - err := database.InitDB() - if err != nil { - fmt.Println("Database initialization failed:", err) - return - } - settingService := service.SettingService{} userService := service.UserService{} @@ -621,26 +654,6 @@ func main() { fmt.Println(config.GetDBTypeFromJSON()) return } - if reset { - resetSetting() - } else { - updateSetting(port, username, password, webBasePath, webDomain, listenIP, resetTwoFactor) - } - if show { - showSetting(show) - } - if getListen { - GetListenIP(getListen) - } - if getCert { - GetCertificate(getCert) - } - if (tgbottoken != "") || (tgbotchatid != "") || (tgbotRuntime != "") { - updateTgbotSetting(tgbottoken, tgbotchatid, tgbotRuntime) - } - if enabletgbot { - updateTgbotEnableSts(enabletgbot) - } if dbTypeFlag != "" { if err := config.WriteSettingToJSON("dbType", dbTypeFlag); err != nil { fmt.Println("Failed to set dbType:", err) @@ -730,6 +743,59 @@ func main() { } } } + opts := settingCommandOptions{ + port: port, + username: username, + password: password, + webBasePath: webBasePath, + webDomain: webDomain, + listenIP: listenIP, + reset: reset, + show: show, + getListen: getListen, + getCert: getCert, + resetTwoFactor: resetTwoFactor, + tgbotToken: tgbottoken, + tgbotChatID: tgbotchatid, + tgbotRuntime: tgbotRuntime, + enableTgbot: enabletgbot, + dbType: dbTypeFlag, + dbHost: dbHost, + dbPort: dbPort, + dbUser: dbUser, + dbPassword: dbPassword, + dbName: dbName, + nodeRoleSet: nodeRoleSet, + nodeIDSet: nodeIDSet, + syncIntervalSet: syncIntervalSet, + trafficFlushIntervalSet: trafficFlushIntervalSet, + } + if opts.needsDBInit() { + if err := database.InitDB(); err != nil { + fmt.Println("Database initialization failed:", err) + return + } + } + if reset { + resetSetting() + } else { + updateSetting(port, username, password, webBasePath, webDomain, listenIP, resetTwoFactor) + } + if show { + showSetting(show) + } + if getListen { + GetListenIP(getListen) + } + if getCert { + GetCertificate(getCert) + } + if (tgbottoken != "") || (tgbotchatid != "") || (tgbotRuntime != "") { + updateTgbotSetting(tgbottoken, tgbotchatid, tgbotRuntime) + } + if enabletgbot { + updateTgbotEnableSts(enabletgbot) + } case "cert": err := settingCmd.Parse(os.Args[2:]) if err != nil { diff --git a/main_test.go b/main_test.go new file mode 100644 index 00000000..f2542c1e --- /dev/null +++ b/main_test.go @@ -0,0 +1,60 @@ +package main + +import "testing" + +func TestSettingCommandNeedsDBInit(t *testing.T) { + tests := []struct { + name string + opts settingCommandOptions + want bool + }{ + { + name: "no flags", + opts: settingCommandOptions{}, + want: false, + }, + { + name: "read only show", + opts: settingCommandOptions{show: true}, + want: true, + }, + { + name: "port update", + opts: settingCommandOptions{port: 2053}, + want: true, + }, + { + name: "db type only", + opts: settingCommandOptions{dbType: "mariadb"}, + want: false, + }, + { + name: "db host only", + opts: settingCommandOptions{dbHost: "127.0.0.1"}, + want: false, + }, + { + name: "node role only", + opts: settingCommandOptions{nodeRoleSet: true}, + want: false, + }, + { + name: "show with db config update", + opts: settingCommandOptions{show: true, dbType: "mariadb"}, + want: true, + }, + { + name: "telegram update", + opts: settingCommandOptions{tgbotToken: "token"}, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.opts.needsDBInit(); got != tt.want { + t.Fatalf("needsDBInit() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tests/mariadb_install_switch_test.sh b/tests/mariadb_install_switch_test.sh new file mode 100644 index 00000000..c975fb79 --- /dev/null +++ b/tests/mariadb_install_switch_test.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +set -euo pipefail + +assert_contains() { + local file="$1" + local pattern="$2" + if ! grep -Fq "$pattern" "$file"; then + echo "missing pattern in $file: $pattern" >&2 + return 1 + fi +} + +assert_contains "install.sh" "本地 MariaDB" +assert_contains "install.sh" "远程 MariaDB" +assert_contains "install.sh" "业务数据库名" +assert_contains "install.sh" "ensure_local_mariadb_ready" + +assert_contains "x-ui.sh" "本地 MariaDB" +assert_contains "x-ui.sh" "远程 MariaDB" +assert_contains "x-ui.sh" "ensure_local_mariadb_ready" +assert_contains "x-ui.sh" "ensure_mariadb_database_and_user" + +echo "mariadb install/switch prompts look correct" diff --git a/update.sh b/update.sh index 9a719635..28ce23cd 100755 --- a/update.sh +++ b/update.sh @@ -85,11 +85,11 @@ is_domain() { is_port_in_use() { local port="$1" if command -v ss >/dev/null 2>&1; then - ss -ltn 2>/dev/null | awk -v p=":${port}$" '$4 ~ p {exit 0} END {exit 1}' + ss -ltn 2>/dev/null | awk -v p=":${port}$" '$4 ~ p {found=1} END {exit(found ? 0 : 1)}' return fi if command -v netstat >/dev/null 2>&1; then - netstat -lnt 2>/dev/null | awk -v p=":${port} " '$4 ~ p {exit 0} END {exit 1}' + netstat -lnt 2>/dev/null | awk -v p=":${port} " '$4 ~ p {found=1} END {exit(found ? 0 : 1)}' return fi if command -v lsof >/dev/null 2>&1; then diff --git a/web/job/xray_traffic_job.go b/web/job/xray_traffic_job.go index bcf0bee0..ed6aab42 100644 --- a/web/job/xray_traffic_job.go +++ b/web/job/xray_traffic_job.go @@ -41,7 +41,7 @@ func (j *XrayTrafficJob) Run() { } needRestart0 := false if service.IsSharedModeEnabled() { - if err := j.trafficFlushSvc.Collect(clientTraffics); err != nil { + if err := j.trafficFlushSvc.Collect(traffics, clientTraffics); err != nil { logger.Warning("collect shared traffic failed:", err) } } else { diff --git a/web/service/node_sync_test.go b/web/service/node_sync_test.go index 33995d92..2cc08a1d 100644 --- a/web/service/node_sync_test.go +++ b/web/service/node_sync_test.go @@ -67,11 +67,11 @@ func TestSyncOnceSkipsApplyWhenVersionUnchanged(t *testing.T) { applyCalled := false loadSnapshotCalled := false syncSvc := &NodeSyncService{ - cachePath: cachePath, - lastSeenVersion: 9, - loadVersion: func() (int64, error) { return 9, nil }, - loadSnapshot: func() (*SharedAccountsSnapshot, error) { loadSnapshotCalled = true; return nil, nil }, - applySnapshot: func(*SharedAccountsSnapshot) error { applyCalled = true; return nil }, + cachePath: cachePath, + lastSeenVersion: 9, + loadVersion: func() (int64, error) { return 9, nil }, + loadSnapshot: func() (*SharedAccountsSnapshot, error) { loadSnapshotCalled = true; return nil, nil }, + applySnapshot: func(*SharedAccountsSnapshot) error { applyCalled = true; return nil }, } didSync, err := syncSvc.SyncOnce() @@ -122,11 +122,11 @@ func TestSyncOnceRefreshesCacheAndAppliesSnapshot(t *testing.T) { applyCalls := 0 syncSvc := &NodeSyncService{ - cachePath: cachePath, - lastSeenVersion: 11, - loadVersion: func() (int64, error) { return 12, nil }, - loadSnapshot: func() (*SharedAccountsSnapshot, error) { return wantSnapshot, nil }, - applySnapshot: func(snapshot *SharedAccountsSnapshot) error { applyCalls++; return nil }, + cachePath: cachePath, + lastSeenVersion: 11, + loadVersion: func() (int64, error) { return 12, nil }, + loadSnapshot: func() (*SharedAccountsSnapshot, error) { return wantSnapshot, nil }, + applySnapshot: func(snapshot *SharedAccountsSnapshot) error { applyCalls++; return nil }, } didSync, err := syncSvc.SyncOnce() diff --git a/web/service/traffic_flush.go b/web/service/traffic_flush.go index 7a53b282..8509717f 100644 --- a/web/service/traffic_flush.go +++ b/web/service/traffic_flush.go @@ -7,6 +7,7 @@ import ( "github.com/mhsanaei/3x-ui/v2/config" "github.com/mhsanaei/3x-ui/v2/database" "github.com/mhsanaei/3x-ui/v2/database/model" + "github.com/mhsanaei/3x-ui/v2/logger" "github.com/mhsanaei/3x-ui/v2/util/common" "github.com/mhsanaei/3x-ui/v2/xray" "gorm.io/gorm" @@ -30,19 +31,73 @@ func NewTrafficFlushService(store *TrafficPendingStore) *TrafficFlushService { return svc } -func (s *TrafficFlushService) Collect(clientTraffics []*xray.ClientTraffic) error { - deltas := make([]TrafficDelta, 0, len(clientTraffics)) +func (s *TrafficFlushService) Collect(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) error { + deltas := make([]TrafficDelta, 0, len(clientTraffics)+len(inboundTraffics)) + clientTotals := map[int]TrafficDelta{} + for _, traffic := range clientTraffics { if traffic == nil || (traffic.Up == 0 && traffic.Down == 0) { continue } - deltas = append(deltas, TrafficDelta{ + delta := TrafficDelta{ + Kind: TrafficDeltaKindClient, InboundID: traffic.InboundId, Email: traffic.Email, UpDelta: traffic.Up, DownDelta: traffic.Down, + } + deltas = append(deltas, delta) + total := clientTotals[traffic.InboundId] + total.UpDelta += traffic.Up + total.DownDelta += traffic.Down + clientTotals[traffic.InboundId] = total + } + + for _, traffic := range inboundTraffics { + if traffic == nil || !traffic.IsInbound || (traffic.Up == 0 && traffic.Down == 0) { + continue + } + + var inbound model.Inbound + if err := database.GetDB().Select("id").First(&inbound, "tag = ?", traffic.Tag).Error; err != nil { + logger.Warning("resolve inbound tag for shared traffic failed:", err) + continue + } + + clientTotal := clientTotals[inbound.Id] + residualUp := traffic.Up - clientTotal.UpDelta + residualDown := traffic.Down - clientTotal.DownDelta + if residualUp < 0 || residualDown < 0 { + logger.Warningf( + "shared traffic residual below zero: tag=%s inbound_id=%d inbound_up=%d inbound_down=%d client_up=%d client_down=%d residual_up=%d residual_down=%d", + traffic.Tag, + inbound.Id, + traffic.Up, + traffic.Down, + clientTotal.UpDelta, + clientTotal.DownDelta, + residualUp, + residualDown, + ) + if residualUp < 0 { + residualUp = 0 + } + if residualDown < 0 { + residualDown = 0 + } + } + if residualUp == 0 && residualDown == 0 { + continue + } + + deltas = append(deltas, TrafficDelta{ + Kind: TrafficDeltaKindInboundOnly, + InboundID: inbound.Id, + UpDelta: residualUp, + DownDelta: residualDown, }) } + if len(deltas) == 0 { return nil } @@ -54,6 +109,11 @@ func (s *TrafficFlushService) flushToDatabase(deltas []TrafficDelta) error { return database.GetDB().Transaction(func(tx *gorm.DB) error { for _, delta := range deltas { + kind := delta.Kind + if kind == "" { + kind = TrafficDeltaKindClient + } + if err := tx.Model(&model.Inbound{}). Where("id = ?", delta.InboundID). Updates(map[string]any{ @@ -64,25 +124,27 @@ func (s *TrafficFlushService) flushToDatabase(deltas []TrafficDelta) error { return err } - row := xray.ClientTraffic{ - InboundId: delta.InboundID, - Email: delta.Email, - Enable: true, - Up: delta.UpDelta, - Down: delta.DownDelta, - AllTime: delta.UpDelta + delta.DownDelta, - LastOnline: now, - } - if err := tx.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "inbound_id"}, {Name: "email"}}, - DoUpdates: clause.Assignments(map[string]any{ - "up": gorm.Expr("up + ?", delta.UpDelta), - "down": gorm.Expr("down + ?", delta.DownDelta), - "all_time": gorm.Expr("COALESCE(all_time, 0) + ?", delta.UpDelta+delta.DownDelta), - "last_online": now, - }), - }).Create(&row).Error; err != nil { - return err + if kind == TrafficDeltaKindClient { + row := xray.ClientTraffic{ + InboundId: delta.InboundID, + Email: delta.Email, + Enable: true, + Up: delta.UpDelta, + Down: delta.DownDelta, + AllTime: delta.UpDelta + delta.DownDelta, + LastOnline: now, + } + if err := tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "inbound_id"}, {Name: "email"}}, + DoUpdates: clause.Assignments(map[string]any{ + "up": gorm.Expr("up + ?", delta.UpDelta), + "down": gorm.Expr("down + ?", delta.DownDelta), + "all_time": gorm.Expr("COALESCE(all_time, 0) + ?", delta.UpDelta+delta.DownDelta), + "last_online": now, + }), + }).Create(&row).Error; err != nil { + return err + } } } diff --git a/web/service/traffic_flush_test.go b/web/service/traffic_flush_test.go index 42df07a4..9f1f0c2b 100644 --- a/web/service/traffic_flush_test.go +++ b/web/service/traffic_flush_test.go @@ -3,10 +3,12 @@ package service import ( "errors" "path/filepath" + "strings" "testing" "github.com/mhsanaei/3x-ui/v2/database" "github.com/mhsanaei/3x-ui/v2/database/model" + "github.com/mhsanaei/3x-ui/v2/logger" "github.com/mhsanaei/3x-ui/v2/xray" "gorm.io/gorm" ) @@ -33,6 +35,57 @@ func TestTrafficPendingStoreMerge(t *testing.T) { } } +func TestCollectPersistsInboundOnlyDeltaFromDifference(t *testing.T) { + setupTestDB(t) + if err := database.GetDB().Create(&model.Inbound{Id: 1, Tag: "inbound-443", Enable: true}).Error; err != nil { + t.Fatalf("seed inbound failed: %v", err) + } + + store := NewTrafficPendingStore(filepath.Join(t.TempDir(), "traffic-pending.json")) + svc := NewTrafficFlushService(store) + + err := svc.Collect( + []*xray.Traffic{{Tag: "inbound-443", IsInbound: true, Up: 100, Down: 50}}, + []*xray.ClientTraffic{{InboundId: 1, Email: "alice@example.com", Up: 70, Down: 20}}, + ) + if err != nil { + t.Fatalf("Collect error: %v", err) + } + + deltas, err := store.Load() + if err != nil { + t.Fatalf("Load error: %v", err) + } + if len(deltas) != 2 { + t.Fatalf("expected 2 deltas, got %d", len(deltas)) + } + + var clientDelta *TrafficDelta + var inboundOnlyDelta *TrafficDelta + for i := range deltas { + switch deltas[i].Kind { + case TrafficDeltaKindClient: + clientDelta = &deltas[i] + case TrafficDeltaKindInboundOnly: + inboundOnlyDelta = &deltas[i] + } + } + + if clientDelta == nil { + t.Fatal("expected client delta to be persisted") + } + if clientDelta.InboundID != 1 || clientDelta.Email != "alice@example.com" || clientDelta.UpDelta != 70 || clientDelta.DownDelta != 20 { + t.Fatalf("unexpected client delta: %+v", *clientDelta) + } + + if inboundOnlyDelta == nil { + t.Fatal("expected inbound-only delta to be persisted") + } + if inboundOnlyDelta.InboundID != 1 || inboundOnlyDelta.Email != "" || inboundOnlyDelta.UpDelta != 30 || inboundOnlyDelta.DownDelta != 30 { + t.Fatalf("unexpected inbound-only delta: %+v", *inboundOnlyDelta) + } +} + func TestFlushOnceClearsPendingOnSuccess(t *testing.T) { setupTestDB(t) @@ -70,6 +123,101 @@ func TestFlushOnceClearsPendingOnSuccess(t *testing.T) { } } +func TestFlushOnceAppliesInboundOnlyDeltaWithoutCreatingClientTraffic(t *testing.T) { + setupTestDB(t) + + if err := database.GetDB().Create(&model.Inbound{Id: 1, Tag: "inbound-443", Enable: true}).Error; err != nil { + t.Fatalf("seed inbound failed: %v", err) + } + if err := database.GetDB().Create(&xray.ClientTraffic{InboundId: 1, Email: "alice@example.com", Enable: true}).Error; err != nil { + t.Fatalf("seed client traffic failed: %v", err) + } + + store := NewTrafficPendingStore(filepath.Join(t.TempDir(), "traffic-pending.json")) + if err := store.Merge([]TrafficDelta{ + {Kind: TrafficDeltaKindClient, InboundID: 1, Email: "alice@example.com", UpDelta: 7, DownDelta: 9}, + {Kind: TrafficDeltaKindInboundOnly, InboundID: 1, UpDelta: 3, DownDelta: 4}, + }); err != nil { + t.Fatalf("Merge error: %v", err) + } + + svc := NewTrafficFlushService(store) + if err := svc.FlushOnce(); err != nil { + t.Fatalf("FlushOnce error: %v", err) + } + + var inbound model.Inbound + if err := database.GetDB().First(&inbound, "id = ?", 1).Error; err != nil { + t.Fatalf("lookup inbound failed: %v", err) + } + if inbound.Up != 10 || inbound.Down != 13 || inbound.AllTime != 23 { + t.Fatalf("unexpected inbound totals: %+v", inbound) + } + + var clientTraffic xray.ClientTraffic + if err := database.GetDB().First(&clientTraffic, "inbound_id = ? AND email = ?", 1, "alice@example.com").Error; err != nil { + t.Fatalf("lookup client traffic failed: %v", err) + } + if clientTraffic.Up != 7 || clientTraffic.Down != 9 { + t.Fatalf("unexpected flushed client traffic: %+v", clientTraffic) + } + + var count int64 + if err := database.GetDB().Model(&xray.ClientTraffic{}).Where("inbound_id = ? AND email = ?", 1, "").Count(&count).Error; err != nil { + t.Fatalf("count inbound-only client rows failed: %v", err) + } + if count != 0 { + t.Fatalf("expected no client_traffics row for inbound-only delta, got %d", count) + } +} + +func TestCollectClampsNegativeResidualAndLogsDetailedWarning(t *testing.T) { + setupTestDB(t) + if err := database.GetDB().Create(&model.Inbound{Id: 1, Tag: "inbound-443", Enable: true}).Error; err != nil { + t.Fatalf("seed inbound failed: %v", err) + } + + store := NewTrafficPendingStore(filepath.Join(t.TempDir(), "traffic-pending.json")) + svc := NewTrafficFlushService(store) + + err := svc.Collect( + []*xray.Traffic{{Tag: "inbound-443", IsInbound: true, Up: 10, Down: 5}}, + []*xray.ClientTraffic{{InboundId: 1, Email: "alice@example.com", Up: 12, Down: 7}}, + ) + if err != nil { + t.Fatalf("Collect error: %v", err) + } + + deltas, err := store.Load() + if err != nil { + t.Fatalf("Load error: %v", err) + } + if len(deltas) != 1 { + t.Fatalf("expected only client delta after clamping negative residual, got %d deltas: %+v", len(deltas), deltas) + } + if deltas[0].Kind != TrafficDeltaKindClient { + t.Fatalf("expected remaining delta to be client kind, got %+v", deltas[0]) + } + + logs := logger.GetLogs(200, "WARNING") + joined := strings.Join(logs, "\n") + for _, want := range []string{ + "shared traffic residual below zero", + "tag=inbound-443", + "inbound_id=1", + "inbound_up=10", + "inbound_down=5", + "client_up=12", + "client_down=7", + "residual_up=-2", + "residual_down=-2", + } { + if !strings.Contains(joined, want) { + t.Fatalf("expected warning logs to contain %q, logs=%s", want, joined) + } + } +} + func TestFlushOnceKeepsPendingOnFailure(t *testing.T) { store := NewTrafficPendingStore(filepath.Join(t.TempDir(), "traffic-pending.json")) if err := store.Merge([]TrafficDelta{{InboundID: 1, Email: "alice@example.com", UpDelta: 3}}); err != nil { diff --git a/web/service/traffic_pending.go b/web/service/traffic_pending.go index 85167424..ac83c789 100644 --- a/web/service/traffic_pending.go +++ b/web/service/traffic_pending.go @@ -9,12 +9,18 @@ import ( ) type TrafficDelta struct { + Kind string `json:"kind"` InboundID int `json:"inboundId"` Email string `json:"email"` UpDelta int64 `json:"upDelta"` DownDelta int64 `json:"downDelta"` } +const ( + TrafficDeltaKindClient = "client" + TrafficDeltaKindInboundOnly = "inbound_only" +) + type TrafficPendingStore struct { path string mu sync.Mutex @@ -47,11 +53,14 @@ func (s *TrafficPendingStore) Merge(newDeltas []TrafficDelta) error { index := map[string]int{} for i, delta := range current { - index[deltaKey(delta.InboundID, delta.Email)] = i + index[deltaKey(delta.Kind, delta.InboundID, delta.Email)] = i } for _, delta := range newDeltas { - key := deltaKey(delta.InboundID, delta.Email) + if delta.Kind == "" { + delta.Kind = TrafficDeltaKindClient + } + key := deltaKey(delta.Kind, delta.InboundID, delta.Email) if idx, ok := index[key]; ok { current[idx].UpDelta += delta.UpDelta current[idx].DownDelta += delta.DownDelta @@ -108,6 +117,6 @@ func (s *TrafficPendingStore) saveUnlocked(deltas []TrafficDelta) error { return os.WriteFile(s.path, data, 0644) } -func deltaKey(inboundID int, email string) string { - return fmt.Sprintf("%d:%s", inboundID, email) +func deltaKey(kind string, inboundID int, email string) string { + return fmt.Sprintf("%s:%d:%s", kind, inboundID, email) } diff --git a/x-ui.sh b/x-ui.sh index 789b9ada..b21b1a26 100644 --- a/x-ui.sh +++ b/x-ui.sh @@ -23,11 +23,11 @@ function LOGI() { is_port_in_use() { local port="$1" if command -v ss >/dev/null 2>&1; then - ss -ltn 2>/dev/null | awk -v p=":${port}$" '$4 ~ p {exit 0} END {exit 1}' + ss -ltn 2>/dev/null | awk -v p=":${port}$" '$4 ~ p {found=1} END {exit(found ? 0 : 1)}' return fi if command -v netstat >/dev/null 2>&1; then - netstat -lnt 2>/dev/null | awk -v p=":${port} " '$4 ~ p {exit 0} END {exit 1}' + netstat -lnt 2>/dev/null | awk -v p=":${port} " '$4 ~ p {found=1} END {exit(found ? 0 : 1)}' return fi if command -v lsof >/dev/null 2>&1; then @@ -2238,6 +2238,7 @@ get_node_setting() { local key="$1" local default_value="$2" local json_path="/etc/x-ui/x-ui.json" + local jq_expr="" if [ ! -f "$json_path" ]; then echo "$default_value" @@ -2245,7 +2246,24 @@ get_node_setting() { fi if command -v jq >/dev/null 2>&1; then - jq -r "$key // $default_value" "$json_path" 2>/dev/null + case "$key" in + ".nodeRole") + jq_expr='.other.nodeRole // .nodeRole // "master"' + ;; + ".nodeId") + jq_expr='.other.nodeId // .nodeId // ""' + ;; + ".syncInterval") + jq_expr='.other.syncInterval // .syncInterval // "30"' + ;; + ".trafficFlushInterval") + jq_expr='.other.trafficFlushInterval // .trafficFlushInterval // "10"' + ;; + *) + jq_expr="$key // $default_value" + ;; + esac + jq -r "$jq_expr" "$json_path" 2>/dev/null return fi @@ -2335,20 +2353,99 @@ set_node_id() { echo -e "${yellow}节点 ID 已更新,建议重启面板使其完全生效。${plain}" } -# Check if MariaDB is installed (server or client) -check_mariadb_installed() { - if command -v mariadb >/dev/null 2>&1 || command -v mysql >/dev/null 2>&1; then +set_sync_interval() { + local sync_interval="" + + read -rp "输入同步间隔(秒): " sync_interval + sync_interval="${sync_interval// /}" + if ! [[ "${sync_interval}" =~ ^[1-9][0-9]*$ ]]; then + echo -e "${red}同步间隔必须为正整数${plain}" + return 1 + fi + if ! ${xui_folder}/x-ui setting -syncInterval "${sync_interval}"; then + echo -e "${red}同步间隔更新失败${plain}" + return 1 + fi + echo -e "${yellow}同步间隔已更新,建议重启面板使其完全生效。${plain}" +} + +set_traffic_flush_interval() { + local flush_interval="" + + read -rp "输入流量回刷间隔(秒): " flush_interval + flush_interval="${flush_interval// /}" + if ! [[ "${flush_interval}" =~ ^[1-9][0-9]*$ ]]; then + echo -e "${red}流量回刷间隔必须为正整数${plain}" + return 1 + fi + if ! ${xui_folder}/x-ui setting -trafficFlushInterval "${flush_interval}"; then + echo -e "${red}流量回刷间隔更新失败${plain}" + return 1 + fi + echo -e "${yellow}流量回刷间隔已更新,建议重启面板使其完全生效。${plain}" +} + +has_mariadb_cli() { + command -v mariadb >/dev/null 2>&1 || command -v mysql >/dev/null 2>&1 +} + +mariadb_cli_bin() { + if command -v mariadb >/dev/null 2>&1; then + command -v mariadb return 0 fi - if systemctl is-active --quiet mariadb 2>/dev/null || systemctl is-active --quiet mysql 2>/dev/null; then + if command -v mysql >/dev/null 2>&1; then + command -v mysql return 0 fi return 1 } -# Install MariaDB server based on distro -install_mariadb() { - echo -e "${green}正在安装 MariaDB...${plain}" +has_local_mariadb_service() { + if command -v systemctl >/dev/null 2>&1; then + systemctl list-unit-files 2>/dev/null | grep -qE '^(mariadb|mysql)\.service$' && return 0 + fi + [[ -f /etc/init.d/mariadb ]] +} + +check_mariadb_installed() { + has_mariadb_cli || has_local_mariadb_service +} + +install_mariadb_client() { + echo -e "${green}正在安装 MariaDB 客户端...${plain}" + case "${release}" in + ubuntu | debian | linuxmint) + apt-get update -y && apt-get install -y mariadb-client + ;; + centos | rhel | almalinux | rocky | ol | alinux | amzn) + if command -v dnf >/dev/null 2>&1; then + dnf install -y mariadb + else + yum install -y mariadb + fi + ;; + fedora) + dnf install -y mariadb + ;; + arch | manjaro) + pacman -Sy --noconfirm mariadb-clients >/dev/null 2>&1 || pacman -Sy --noconfirm mariadb + ;; + opensuse* | sles) + zypper install -y mariadb-client + ;; + alpine) + apk add mariadb-client + ;; + *) + echo -e "${red}不支持的发行版: ${release},请手动安装 MariaDB 客户端${plain}" + return 1 + ;; + esac +} + +install_local_mariadb_server() { + echo -e "${green}正在安装本地 MariaDB...${plain}" case "${release}" in ubuntu | debian | linuxmint) apt-get update -y && apt-get install -y mariadb-server mariadb-client @@ -2365,14 +2462,14 @@ install_mariadb() { ;; arch | manjaro) pacman -Sy --noconfirm mariadb - mariadb-install-db --user=mysql --basedir=/usr --datadir=/var/lib/mysql >/dev/null 2>&1 + mariadb-install-db --user=mysql --basedir=/usr --datadir=/var/lib/mysql >/dev/null 2>&1 || true ;; opensuse* | sles) zypper install -y mariadb-server mariadb-client ;; alpine) apk add mariadb mariadb-client - mariadb-install-db --user=mysql --basedir=/usr --datadir=/var/lib/mysql >/dev/null 2>&1 + mariadb-install-db --user=mysql --basedir=/usr --datadir=/var/lib/mysql >/dev/null 2>&1 || true ;; *) echo -e "${red}不支持的发行版: ${release},请手动安装 MariaDB${plain}" @@ -2388,55 +2485,153 @@ install_mariadb() { return $ret } -# Start and enable MariaDB service start_mariadb_service() { local svc_name="" - if systemctl list-unit-files | grep -q "^mariadb.service"; then - svc_name="mariadb" - elif systemctl list-unit-files | grep -q "^mysql.service"; then - svc_name="mysql" + if command -v systemctl >/dev/null 2>&1; then + if systemctl list-unit-files 2>/dev/null | grep -q "^mariadb.service"; then + svc_name="mariadb" + elif systemctl list-unit-files 2>/dev/null | grep -q "^mysql.service"; then + svc_name="mysql" + fi fi if [ -n "$svc_name" ]; then systemctl start "$svc_name" 2>/dev/null systemctl enable "$svc_name" 2>/dev/null - else - # alpine / no systemd + return 0 + fi + if [[ $release == "alpine" ]]; then rc-service mariadb start 2>/dev/null rc-update add mariadb 2>/dev/null + return $? fi + return 1 } -# Test MariaDB connection with given credentials -# Args: host port user pass -test_mariadb_connection() { +ensure_mariadb_client_ready() { + if has_mariadb_cli; then + return 0 + fi + echo -e "${yellow}未检测到 MariaDB 客户端${plain}" + confirm "是否安装 MariaDB 客户端?" "y" || return 1 + install_mariadb_client || return 1 + has_mariadb_cli +} + +ensure_local_mariadb_ready() { + if ! has_local_mariadb_service; then + echo -e "${yellow}未检测到本地 MariaDB 服务${plain}" + confirm "是否安装本地 MariaDB?" "y" || return 1 + install_local_mariadb_server || return 1 + fi + ensure_mariadb_client_ready || return 1 + start_mariadb_service || true + return 0 +} + +test_mariadb_server_connection() { local host="$1" port="$2" user="$3" pass="$4" - mariadb -h "$host" -P "$port" -u "$user" -p"$pass" -e "SELECT 1;" >/dev/null 2>&1 - return $? + local bin + bin=$(mariadb_cli_bin) || return 1 + "$bin" -h "$host" -P "$port" -u "$user" -p"$pass" -e "SELECT 1;" >/dev/null 2>&1 } -# Test connection to a specific database, create if not exists -# Args: host port user pass dbname -ensure_mariadb_database() { - local host="$1" port="$2" user="$3" pass="$4" dbname="$5" - # Check if database exists - local exists - exists=$(mariadb -h "$host" -P "$port" -u "$user" -p"$pass" -N -B -e \ - "SELECT COUNT(*) FROM information_schema.SCHEMATA WHERE SCHEMA_NAME='${dbname}';" 2>/dev/null) - if [ "$exists" = "1" ]; then - echo -e "${green}数据库 '${dbname}' 已存在${plain}" +test_mariadb_database_connection() { + local host="$1" port="$2" dbname="$3" user="$4" pass="$5" + local bin + bin=$(mariadb_cli_bin) || return 1 + "$bin" -h "$host" -P "$port" -u "$user" -p"$pass" -D "$dbname" -e "SELECT 1;" >/dev/null 2>&1 +} + +is_safe_mariadb_identifier() { + [[ "$1" =~ ^[A-Za-z0-9_.-]+$ ]] +} + +escape_sql_string() { + printf "%s" "$1" | sed "s/'/''/g" +} + +LOCAL_MARIADB_ADMIN_MODE="" +LOCAL_MARIADB_ADMIN_USER="" +LOCAL_MARIADB_ADMIN_PASS="" +LOCAL_MARIADB_ADMIN_PORT="3306" + +try_local_mariadb_socket_admin() { + local bin + bin=$(mariadb_cli_bin) || return 1 + "$bin" -e "SELECT 1;" >/dev/null 2>&1 || "$bin" -uroot -e "SELECT 1;" >/dev/null 2>&1 +} + +ensure_local_mariadb_admin_access() { + local port="${1:-3306}" + LOCAL_MARIADB_ADMIN_PORT="$port" + + if try_local_mariadb_socket_admin; then + LOCAL_MARIADB_ADMIN_MODE="socket" return 0 fi - # Create database - echo -e "${green}正在创建数据库 '${dbname}'...${plain}" - mariadb -h "$host" -P "$port" -u "$user" -p"$pass" -e \ - "CREATE DATABASE \`${dbname}\` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;" 2>/dev/null - if [ $? -eq 0 ]; then - echo -e "${green}数据库 '${dbname}' 创建成功${plain}" - return 0 - else - echo -e "${red}数据库 '${dbname}' 创建失败${plain}" + + local admin_user admin_pass + echo -e "${yellow}无法通过 root socket 直接连接本地 MariaDB,请输入管理员账号信息。${plain}" + read -rp "MariaDB 管理员用户名 [root]: " admin_user + admin_user="${admin_user:-root}" + read -rsp "MariaDB 管理员密码: " admin_pass + echo + + if ! test_mariadb_server_connection "127.0.0.1" "$port" "$admin_user" "$admin_pass"; then + echo -e "${red}管理员账号连接失败${plain}" return 1 fi + + LOCAL_MARIADB_ADMIN_MODE="password" + LOCAL_MARIADB_ADMIN_USER="$admin_user" + LOCAL_MARIADB_ADMIN_PASS="$admin_pass" +} + +run_local_mariadb_admin_sql() { + local sql="$1" + local bin + bin=$(mariadb_cli_bin) || return 1 + + case "$LOCAL_MARIADB_ADMIN_MODE" in + socket) + "$bin" -e "$sql" >/dev/null 2>&1 || "$bin" -uroot -e "$sql" >/dev/null 2>&1 + ;; + password) + "$bin" -h "127.0.0.1" -P "$LOCAL_MARIADB_ADMIN_PORT" -u "$LOCAL_MARIADB_ADMIN_USER" -p"$LOCAL_MARIADB_ADMIN_PASS" -e "$sql" >/dev/null 2>&1 + ;; + *) + return 1 + ;; + esac +} + +ensure_mariadb_database_and_user() { + local dbname="$1" dbuser="$2" dbpass="$3" + local escaped_pass + local sql="" + local account_host="" + + if ! is_safe_mariadb_identifier "$dbname"; then + echo -e "${red}业务数据库名仅支持字母、数字、点、下划线和连字符${plain}" + return 1 + fi + if ! is_safe_mariadb_identifier "$dbuser"; then + echo -e "${red}业务用户名仅支持字母、数字、点、下划线和连字符${plain}" + return 1 + fi + + escaped_pass=$(escape_sql_string "$dbpass") + sql="CREATE DATABASE IF NOT EXISTS \`${dbname}\` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;" + + for account_host in "localhost" "127.0.0.1" "::1"; do + sql="${sql} CREATE USER IF NOT EXISTS '${dbuser}'@'${account_host}' IDENTIFIED BY '${escaped_pass}';" + sql="${sql} ALTER USER '${dbuser}'@'${account_host}' IDENTIFIED BY '${escaped_pass}';" + sql="${sql} GRANT ALL PRIVILEGES ON \`${dbname}\`.* TO '${dbuser}'@'${account_host}';" + done + sql="${sql} FLUSH PRIVILEGES;" + + echo -e "${green}正在确保本地 MariaDB 的业务库和业务账号存在...${plain}" + run_local_mariadb_admin_sql "$sql" } # Switch to MariaDB @@ -2448,77 +2643,86 @@ db_switch_to_mariadb() { return fi - # Step 1: Check MariaDB installation - if ! check_mariadb_installed; then - echo -e "${yellow}未检测到 MariaDB${plain}" - confirm "是否安装 MariaDB?" "y" - if [ $? -ne 0 ]; then - echo -e "${yellow}已取消安装,返回数据库菜单${plain}" + local mariadb_mode_choice mariadb_mode + local db_host db_port db_user db_pass db_name + + read -rp "MariaDB 部署位置 [1=本地 MariaDB, 2=远程 MariaDB,默认 1]: " mariadb_mode_choice + case "${mariadb_mode_choice:-1}" in + 2) + mariadb_mode="remote" + ;; + *) + mariadb_mode="local" + ;; + esac + + if [[ "${mariadb_mode}" == "remote" ]]; then + ensure_mariadb_client_ready || { + echo -e "${yellow}已取消安装 MariaDB 客户端,返回数据库菜单${plain}" + db_menu + return + } + + echo -e "${green}请输入远程 MariaDB 业务连接信息(直接回车使用默认值):${plain}" + read -rp "远程 MariaDB host [127.0.0.1]: " db_host + read -rp "远程 MariaDB port [3306]: " db_port + read -rp "业务数据库名 [3xui]: " db_name + read -rp "业务用户名: " db_user + read -rsp "业务密码: " db_pass + echo + + db_host=${db_host:-127.0.0.1} + db_port=${db_port:-3306} + db_name=${db_name:-3xui} + if [[ -z "$db_user" || -z "$db_pass" ]]; then + echo -e "${red}业务用户名和业务密码不能为空${plain}" db_menu return fi - install_mariadb - if [ $? -ne 0 ]; then - echo -e "${red}MariaDB 安装失败,返回数据库菜单${plain}" + + echo -e "${green}正在验证远程 MariaDB 业务连接...${plain}" + if ! test_mariadb_database_connection "$db_host" "$db_port" "$db_name" "$db_user" "$db_pass"; then + echo -e "${red}无法使用输入的远程 MariaDB 信息连接到业务数据库${plain}" db_menu return fi - start_mariadb_service - if ! check_mariadb_installed; then - echo -e "${red}MariaDB 安装后仍无法检测到,请手动检查${plain}" - db_menu - return - fi - echo -e "${green}MariaDB 已安装并启动${plain}" else - echo -e "${green}MariaDB 已安装${plain}" - # Ensure service is running - start_mariadb_service + db_host="127.0.0.1" + db_port="3306" + read -rp "业务数据库名 [3xui]: " db_name + read -rp "业务用户名: " db_user + read -rsp "业务密码: " db_pass + echo + + db_name=${db_name:-3xui} + if [[ -z "$db_user" || -z "$db_pass" ]]; then + echo -e "${red}业务用户名和业务密码不能为空${plain}" + db_menu + return + fi + + ensure_local_mariadb_ready || { + echo -e "${yellow}本地 MariaDB 未准备完成,返回数据库菜单${plain}" + db_menu + return + } + ensure_local_mariadb_admin_access "$db_port" || { + db_menu + return + } + ensure_mariadb_database_and_user "$db_name" "$db_user" "$db_pass" || { + db_menu + return + } + + echo -e "${green}正在验证本地 MariaDB 业务连接...${plain}" + if ! test_mariadb_database_connection "$db_host" "$db_port" "$db_name" "$db_user" "$db_pass"; then + echo -e "${red}无法使用创建后的本地 MariaDB 业务账号连接数据库${plain}" + db_menu + return + fi fi - # Step 2: Collect connection info - echo -e "${green}请输入 MariaDB 连接信息(直接回车使用默认值):${plain}" - - read -rp "MariaDB IP(默认 127.0.0.1): " db_host - db_host=${db_host:-127.0.0.1} - - read -rp "MariaDB 端口(默认 3306): " db_port - db_port=${db_port:-3306} - - read -rp "MariaDB 用户名: " db_user - if [ -z "$db_user" ]; then - echo -e "${red}用户名不能为空${plain}" - db_menu - return - fi - - read -rsp "MariaDB 密码: " db_pass - echo - if [ -z "$db_pass" ]; then - echo -e "${red}密码不能为空${plain}" - db_menu - return - fi - - read -rp "数据库名(默认 3xui): " db_name - db_name=${db_name:-3xui} - - # Step 3: Test connection - echo -e "${green}正在测试数据库连接...${plain}" - if ! test_mariadb_connection "$db_host" "$db_port" "$db_user" "$db_pass"; then - echo -e "${red}无法连接到 MariaDB,请检查用户名、密码及主机信息${plain}" - db_menu - return - fi - echo -e "${green}数据库连接成功${plain}" - - # Step 4: Ensure database exists - if ! ensure_mariadb_database "$db_host" "$db_port" "$db_user" "$db_pass" "$db_name"; then - db_menu - return - fi - - # Step 5: Save config and migrate echo -e "${green}正在配置 MariaDB 连接...${plain}" XUI_DB_PASSWORD="$db_pass" ${xui_folder}/x-ui setting -dbHost "$db_host" -dbPort "$db_port" -dbUser "$db_user" -dbName "$db_name" >/dev/null 2>&1 @@ -2572,9 +2776,11 @@ db_menu() { │ ${green}4.${plain} 查看当前节点设置 │ │ ${green}5.${plain} 设置节点角色 │ │ ${green}6.${plain} 设置节点 ID │ +│ ${green}7.${plain} 设置同步间隔 │ +│ ${green}8.${plain} 设置流量回刷间隔 │ ╚════════════════════════════════════════════════╝ " - read -rp "请输入选择 [0-6]:" num + read -rp "请输入选择 [0-8]:" num case "${num}" in 0) show_menu @@ -2601,6 +2807,14 @@ db_menu() { set_node_id db_menu ;; + 7) + set_sync_interval + db_menu + ;; + 8) + set_traffic_flush_interval + db_menu + ;; *) echo -e "${red}无效选项,请选择有效数字。${plain}\n" db_menu