diff --git a/database/db.go b/database/db.go index d5caa702..97ef2160 100644 --- a/database/db.go +++ b/database/db.go @@ -17,7 +17,16 @@ import ( var db *gorm.DB +var initializers = []func() error{ + initUser, + initInbound, + initSetting, + initInboundClientIps, + initClientTraffic, +} + func initUser() error { + err := db.AutoMigrate(&model.User{}) if err != nil { return err @@ -54,7 +63,7 @@ func initClientTraffic() error { func InitDB(dbPath string) error { dir := path.Dir(dbPath) - err := os.MkdirAll(dir, fs.ModeDir) + err := os.MkdirAll(dir, fs.ModePerm) if err != nil { return err } @@ -75,25 +84,10 @@ func InitDB(dbPath string) error { return err } - err = initUser() - if err != nil { - return err - } - err = initInbound() - if err != nil { - return err - } - err = initSetting() - if err != nil { - return err - } - err = initInboundClientIps() - if err != nil { - return err - } - err = initClientTraffic() - if err != nil { - return err + for _, initialize := range initializers { + if err := initialize(); err != nil { + return err + } } return nil @@ -107,10 +101,10 @@ func IsNotFound(err error) bool { return err == gorm.ErrRecordNotFound } -func IsSQLiteDB(file io.Reader) (bool, error) { +func IsSQLiteDB(file io.ReaderAt) (bool, error) { signature := []byte("SQLite format 3\x00") buf := make([]byte, len(signature)) - _, err := file.Read(buf) + _, err := file.ReadAt(buf, 0) if err != nil { return false, err } diff --git a/logger/logger.go b/logger/logger.go index 405f2ec7..f5b0a878 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,17 +2,28 @@ package logger import ( "os" + "sync" "github.com/op/go-logging" ) -var logger *logging.Logger +var ( + logger *logging.Logger + mu sync.Mutex +) func init() { InitLogger(logging.INFO) } func InitLogger(level logging.Level) { + mu.Lock() + defer mu.Unlock() + + if logger != nil { + return + } + format := logging.MustStringFormatter( `%{time:2006/01/02 15:04:05} %{level} - %{message}`, ) @@ -21,39 +32,55 @@ func InitLogger(level logging.Level) { backendFormatter := logging.NewBackendFormatter(backend, format) backendLeveled := logging.AddModuleLevel(backendFormatter) backendLeveled.SetLevel(level, "") - newLogger.SetBackend(backendLeveled) + newLogger.SetBackend(logging.MultiLogger(backendLeveled)) logger = newLogger } func Debug(args ...interface{}) { - logger.Debug(args...) + if logger != nil { + logger.Debug(args...) + } } func Debugf(format string, args ...interface{}) { - logger.Debugf(format, args...) + if logger != nil { + logger.Debugf(format, args...) + } } func Info(args ...interface{}) { - logger.Info(args...) + if logger != nil { + logger.Info(args...) + } } func Infof(format string, args ...interface{}) { - logger.Infof(format, args...) + if logger != nil { + logger.Infof(format, args...) + } } func Warning(args ...interface{}) { - logger.Warning(args...) + if logger != nil { + logger.Warning(args...) + } } func Warningf(format string, args ...interface{}) { - logger.Warningf(format, args...) + if logger != nil { + logger.Warningf(format, args...) + } } func Error(args ...interface{}) { - logger.Error(args...) + if logger != nil { + logger.Error(args...) + } } func Errorf(format string, args ...interface{}) { - logger.Errorf(format, args...) + if logger != nil { + logger.Errorf(format, args...) + } } diff --git a/util/sys/sys_linux.go b/util/sys/sys_linux.go index bcb7046e..097ff7c7 100644 --- a/util/sys/sys_linux.go +++ b/util/sys/sys_linux.go @@ -24,8 +24,8 @@ func getLinesNum(filename string) (int, error) { var buffPosition int for { - i := bytes.IndexByte(buf[buffPosition:], '\n') - if i < 0 || n == buffPosition { + i := bytes.IndexByte(buf[buffPosition:n], '\n') + if i < 0 { break } buffPosition += i + 1 @@ -33,11 +33,12 @@ func getLinesNum(filename string) (int, error) { } if err == io.EOF { - return sum, nil + break } else if err != nil { - return sum, err + return 0, err } } + return sum, nil } func GetTCPCount() (int, error) { @@ -45,11 +46,11 @@ func GetTCPCount() (int, error) { tcp4, err := getLinesNum(fmt.Sprintf("%v/net/tcp", root)) if err != nil { - return tcp4, err + return 0, err } tcp6, err := getLinesNum(fmt.Sprintf("%v/net/tcp6", root)) if err != nil { - return tcp4 + tcp6, nil + return 0, err } return tcp4 + tcp6, nil @@ -60,11 +61,11 @@ func GetUDPCount() (int, error) { udp4, err := getLinesNum(fmt.Sprintf("%v/net/udp", root)) if err != nil { - return udp4, err + return 0, err } udp6, err := getLinesNum(fmt.Sprintf("%v/net/udp6", root)) if err != nil { - return udp4 + udp6, nil + return 0, err } return udp4 + udp6, nil diff --git a/util/sys/sys_windows.go b/util/sys/sys_windows.go index ef9cdf54..34740ea6 100644 --- a/util/sys/sys_windows.go +++ b/util/sys/sys_windows.go @@ -4,21 +4,27 @@ package sys import ( + "errors" + "github.com/shirou/gopsutil/v3/net" ) -func GetTCPCount() (int, error) { - stats, err := net.Connections("tcp") +func GetConnectionCount(proto string) (int, error) { + if proto != "tcp" && proto != "udp" { + return 0, errors.New("invalid protocol") + } + + stats, err := net.Connections(proto) if err != nil { return 0, err } return len(stats), nil } -func GetUDPCount() (int, error) { - stats, err := net.Connections("udp") - if err != nil { - return 0, err - } - return len(stats), nil +func GetTCPCount() (int, error) { + return GetConnectionCount("tcp") +} + +func GetUDPCount() (int, error) { + return GetConnectionCount("udp") } diff --git a/x-ui.sh b/x-ui.sh index d8d880c5..3592ca63 100644 --- a/x-ui.sh +++ b/x-ui.sh @@ -644,27 +644,37 @@ warp_fixchatgpt() { run_speedtest() { # Check if Speedtest is already installed - if ! command -v speedtest &>/dev/null; then + if ! command -v speedtest &> /dev/null; then # If not installed, install it - if command -v dnf &>/dev/null; then - sudo dnf install -y curl - curl -s https://packagecloud.io/install/repositories/ookla/speedtest-cli/script.rpm.sh | sudo bash - sudo dnf install -y speedtest - elif command -v yum &>/dev/null; then - sudo yum install -y curl - curl -s https://packagecloud.io/install/repositories/ookla/speedtest-cli/script.rpm.sh | sudo bash - sudo yum install -y speedtest - elif command -v apt-get &>/dev/null; then - sudo apt-get update && sudo apt-get install -y curl - curl -s https://packagecloud.io/install/repositories/ookla/speedtest-cli/script.deb.sh | sudo bash - sudo apt-get install -y speedtest - elif command -v apt &>/dev/null; then - sudo apt update && sudo apt install -y curl - curl -s https://packagecloud.io/install/repositories/ookla/speedtest-cli/script.deb.sh | sudo bash - sudo apt install -y speedtest - else + local pkg_manager="" + local curl_install_cmd="" + local speedtest_install_script="" + + if command -v dnf &> /dev/null; then + pkg_manager="dnf" + curl_install_cmd="sudo dnf install -y curl" + speedtest_install_script="https://packagecloud.io/install/repositories/ookla/speedtest-cli/script.rpm.sh" + elif command -v yum &> /dev/null; then + pkg_manager="yum" + curl_install_cmd="sudo yum install -y curl" + speedtest_install_script="https://packagecloud.io/install/repositories/ookla/speedtest-cli/script.rpm.sh" + elif command -v apt-get &> /dev/null; then + pkg_manager="apt-get" + curl_install_cmd="sudo apt-get update && sudo apt-get install -y curl" + speedtest_install_script="https://packagecloud.io/install/repositories/ookla/speedtest-cli/script.deb.sh" + elif command -v apt &> /dev/null; then + pkg_manager="apt" + curl_install_cmd="sudo apt update && sudo apt install -y curl" + speedtest_install_script="https://packagecloud.io/install/repositories/ookla/speedtest-cli/script.deb.sh" + fi + + if [[ -z $pkg_manager ]]; then echo "Error: Package manager not found. You may need to install Speedtest manually." return 1 + else + $curl_install_cmd + curl -s $speedtest_install_script | sudo bash + sudo $pkg_manager install -y speedtest fi fi diff --git a/xray/process.go b/xray/process.go index 3ce016ed..3f6bf294 100644 --- a/xray/process.go +++ b/xray/process.go @@ -13,6 +13,7 @@ import ( "regexp" "runtime" "strings" + "sync" "time" "x-ui/config" "x-ui/util/common" @@ -184,11 +185,11 @@ func (p *process) Start() (err error) { return err } + var wg sync.WaitGroup + wg.Add(2) + go func() { - defer func() { - common.Recover("") - stdReader.Close() - }() + defer wg.Done() reader := bufio.NewReaderSize(stdReader, 8192) for { line, _, err := reader.ReadLine() @@ -203,10 +204,7 @@ func (p *process) Start() (err error) { }() go func() { - defer func() { - common.Recover("") - errReader.Close() - }() + defer wg.Done() reader := bufio.NewReaderSize(errReader, 8192) for { line, _, err := reader.ReadLine() @@ -225,6 +223,7 @@ func (p *process) Start() (err error) { if err != nil { p.exitErr = err } + wg.Wait() }() p.refreshVersion()