mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-05 20:54:14 +00:00
fix: address copilot review on #4545
Two issues raised by the Copilot review: 1) subscriptionExpiryFromClient called time.Now() per invocation. Two clients with the same delayed-start duration normalized to timestamps a few milliseconds apart, so the aggregator's "if normalized != traffic.ExpiryTime" check tripped and the subscription header expire= dropped back to 0 — the exact bug the helper was meant to fix, just one client later. Take nowMs as a parameter; each of GetSubs / GetClash / GetConfig captures one timestamp per request and reuses it. 2) Guarding Flow against empty incoming values in SyncInbound prevented a user from ever clearing a VLESS flow via the panel. FlowOverride on client_inbounds is the per-inbound mechanism that already preserves flow correctly across protocols, so the guard on the shared clients.flow column is the wrong place. Drop the Flow guard, keep the rest (UUID/Password/Auth/Security/ Reverse — none of which have a per-inbound override column). Adds a regression test that asserts clearing flow on the owning inbound makes ListForInbound return flow="". The existing cross-protocol test is rewritten to assert on the user-visible behavior (ListForInbound flow) instead of the shared clients.flow column.
This commit is contained in:
parent
934f9bc230
commit
0ae57881d5
6 changed files with 81 additions and 24 deletions
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
"maps"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/goccy/go-json"
|
"github.com/goccy/go-json"
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
|
|
@ -63,12 +64,13 @@ func (s *SubClashService) GetClash(subId string, host string) (string, string, e
|
||||||
return "", "", nil
|
return "", "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
now := time.Now().UnixMilli()
|
||||||
for index, clientTraffic := range clientTraffics {
|
for index, clientTraffic := range clientTraffics {
|
||||||
if index == 0 {
|
if index == 0 {
|
||||||
traffic.Up = clientTraffic.Up
|
traffic.Up = clientTraffic.Up
|
||||||
traffic.Down = clientTraffic.Down
|
traffic.Down = clientTraffic.Down
|
||||||
traffic.Total = clientTraffic.Total
|
traffic.Total = clientTraffic.Total
|
||||||
traffic.ExpiryTime = subscriptionExpiryFromClient(clientTraffic.ExpiryTime)
|
traffic.ExpiryTime = subscriptionExpiryFromClient(now, clientTraffic.ExpiryTime)
|
||||||
} else {
|
} else {
|
||||||
traffic.Up += clientTraffic.Up
|
traffic.Up += clientTraffic.Up
|
||||||
traffic.Down += clientTraffic.Down
|
traffic.Down += clientTraffic.Down
|
||||||
|
|
@ -77,7 +79,7 @@ func (s *SubClashService) GetClash(subId string, host string) (string, string, e
|
||||||
} else {
|
} else {
|
||||||
traffic.Total += clientTraffic.Total
|
traffic.Total += clientTraffic.Total
|
||||||
}
|
}
|
||||||
normalized := subscriptionExpiryFromClient(clientTraffic.ExpiryTime)
|
normalized := subscriptionExpiryFromClient(now, clientTraffic.ExpiryTime)
|
||||||
if normalized != traffic.ExpiryTime {
|
if normalized != traffic.ExpiryTime {
|
||||||
traffic.ExpiryTime = 0
|
traffic.ExpiryTime = 0
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
"maps"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/mhsanaei/3x-ui/v3/database/model"
|
"github.com/mhsanaei/3x-ui/v3/database/model"
|
||||||
"github.com/mhsanaei/3x-ui/v3/logger"
|
"github.com/mhsanaei/3x-ui/v3/logger"
|
||||||
|
|
@ -125,12 +126,13 @@ func (s *SubJsonService) GetJson(subId string, host string) (string, string, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare statistics
|
// Prepare statistics
|
||||||
|
now := time.Now().UnixMilli()
|
||||||
for index, clientTraffic := range clientTraffics {
|
for index, clientTraffic := range clientTraffics {
|
||||||
if index == 0 {
|
if index == 0 {
|
||||||
traffic.Up = clientTraffic.Up
|
traffic.Up = clientTraffic.Up
|
||||||
traffic.Down = clientTraffic.Down
|
traffic.Down = clientTraffic.Down
|
||||||
traffic.Total = clientTraffic.Total
|
traffic.Total = clientTraffic.Total
|
||||||
traffic.ExpiryTime = subscriptionExpiryFromClient(clientTraffic.ExpiryTime)
|
traffic.ExpiryTime = subscriptionExpiryFromClient(now, clientTraffic.ExpiryTime)
|
||||||
} else {
|
} else {
|
||||||
traffic.Up += clientTraffic.Up
|
traffic.Up += clientTraffic.Up
|
||||||
traffic.Down += clientTraffic.Down
|
traffic.Down += clientTraffic.Down
|
||||||
|
|
@ -139,7 +141,7 @@ func (s *SubJsonService) GetJson(subId string, host string) (string, string, err
|
||||||
} else {
|
} else {
|
||||||
traffic.Total += clientTraffic.Total
|
traffic.Total += clientTraffic.Total
|
||||||
}
|
}
|
||||||
normalized := subscriptionExpiryFromClient(clientTraffic.ExpiryTime)
|
normalized := subscriptionExpiryFromClient(now, clientTraffic.ExpiryTime)
|
||||||
if normalized != traffic.ExpiryTime {
|
if normalized != traffic.ExpiryTime {
|
||||||
traffic.ExpiryTime = 0
|
traffic.ExpiryTime = 0
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -108,13 +108,13 @@ func (s *SubService) GetSubs(subId string, host string) ([]string, int64, xray.C
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare statistics
|
now := time.Now().UnixMilli()
|
||||||
for index, clientTraffic := range clientTraffics {
|
for index, clientTraffic := range clientTraffics {
|
||||||
if index == 0 {
|
if index == 0 {
|
||||||
traffic.Up = clientTraffic.Up
|
traffic.Up = clientTraffic.Up
|
||||||
traffic.Down = clientTraffic.Down
|
traffic.Down = clientTraffic.Down
|
||||||
traffic.Total = clientTraffic.Total
|
traffic.Total = clientTraffic.Total
|
||||||
traffic.ExpiryTime = subscriptionExpiryFromClient(clientTraffic.ExpiryTime)
|
traffic.ExpiryTime = subscriptionExpiryFromClient(now, clientTraffic.ExpiryTime)
|
||||||
} else {
|
} else {
|
||||||
traffic.Up += clientTraffic.Up
|
traffic.Up += clientTraffic.Up
|
||||||
traffic.Down += clientTraffic.Down
|
traffic.Down += clientTraffic.Down
|
||||||
|
|
@ -123,7 +123,7 @@ func (s *SubService) GetSubs(subId string, host string) ([]string, int64, xray.C
|
||||||
} else {
|
} else {
|
||||||
traffic.Total += clientTraffic.Total
|
traffic.Total += clientTraffic.Total
|
||||||
}
|
}
|
||||||
normalized := subscriptionExpiryFromClient(clientTraffic.ExpiryTime)
|
normalized := subscriptionExpiryFromClient(now, clientTraffic.ExpiryTime)
|
||||||
if normalized != traffic.ExpiryTime {
|
if normalized != traffic.ExpiryTime {
|
||||||
traffic.ExpiryTime = 0
|
traffic.ExpiryTime = 0
|
||||||
}
|
}
|
||||||
|
|
@ -133,12 +133,12 @@ func (s *SubService) GetSubs(subId string, host string) ([]string, int64, xray.C
|
||||||
return result, lastOnline, traffic, nil
|
return result, lastOnline, traffic, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func subscriptionExpiryFromClient(expiryTime int64) int64 {
|
func subscriptionExpiryFromClient(nowMs, expiryTime int64) int64 {
|
||||||
if expiryTime > 0 {
|
if expiryTime > 0 {
|
||||||
return expiryTime
|
return expiryTime
|
||||||
}
|
}
|
||||||
if expiryTime < 0 {
|
if expiryTime < 0 {
|
||||||
return time.Now().UnixMilli() + (-expiryTime)
|
return nowMs + (-expiryTime)
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,24 +5,24 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/mhsanaei/3x-ui/v3/database/model"
|
"github.com/mhsanaei/3x-ui/v3/database/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSubscriptionExpiryFromClient(t *testing.T) {
|
func TestSubscriptionExpiryFromClient(t *testing.T) {
|
||||||
if got := subscriptionExpiryFromClient(0); got != 0 {
|
const now = int64(1_700_000_000_000)
|
||||||
|
const oneDayMs = int64(86_400_000)
|
||||||
|
if got := subscriptionExpiryFromClient(now, 0); got != 0 {
|
||||||
t.Fatalf("zero expiry should stay zero, got %d", got)
|
t.Fatalf("zero expiry should stay zero, got %d", got)
|
||||||
}
|
}
|
||||||
if got := subscriptionExpiryFromClient(1_700_000_000_000); got != 1_700_000_000_000 {
|
if got := subscriptionExpiryFromClient(now, 1_700_000_000_000); got != 1_700_000_000_000 {
|
||||||
t.Fatalf("positive expiry should pass through, got %d", got)
|
t.Fatalf("positive expiry should pass through, got %d", got)
|
||||||
}
|
}
|
||||||
const oneDayMs = int64(86_400_000)
|
if got := subscriptionExpiryFromClient(now, -oneDayMs); got != now+oneDayMs {
|
||||||
before := time.Now().UnixMilli()
|
t.Fatalf("delayed-start expiry should be now+|value|, got %d, want %d", got, now+oneDayMs)
|
||||||
got := subscriptionExpiryFromClient(-oneDayMs)
|
}
|
||||||
after := time.Now().UnixMilli()
|
if a, b := subscriptionExpiryFromClient(now, -oneDayMs), subscriptionExpiryFromClient(now, -oneDayMs); a != b {
|
||||||
if got < before+oneDayMs || got > after+oneDayMs {
|
t.Fatalf("same now+value should be deterministic across calls, got %d vs %d (#4545 review)", a, b)
|
||||||
t.Fatalf("delayed-start expiry should land ~1 day from now, got %d (window %d..%d)", got, before+oneDayMs, after+oneDayMs)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -222,9 +222,7 @@ func (s *ClientService) SyncInbound(tx *gorm.DB, inboundId int, clients []model.
|
||||||
if incoming.Auth != "" {
|
if incoming.Auth != "" {
|
||||||
row.Auth = incoming.Auth
|
row.Auth = incoming.Auth
|
||||||
}
|
}
|
||||||
if incoming.Flow != "" {
|
row.Flow = incoming.Flow
|
||||||
row.Flow = incoming.Flow
|
|
||||||
}
|
|
||||||
if incoming.Security != "" {
|
if incoming.Security != "" {
|
||||||
row.Security = incoming.Security
|
row.Security = incoming.Security
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,9 @@ func TestSyncInbound_PreservesCredentialsAcrossProtocols(t *testing.T) {
|
||||||
const sharedEmail = "shared@example.com"
|
const sharedEmail = "shared@example.com"
|
||||||
const wantUUID = "ce8d33df-3a64-4f10-8f9b-91c3a8e0c001"
|
const wantUUID = "ce8d33df-3a64-4f10-8f9b-91c3a8e0c001"
|
||||||
const wantAuth = "h2-auth-token"
|
const wantAuth = "h2-auth-token"
|
||||||
|
const wantFlow = "xtls-rprx-vision"
|
||||||
|
|
||||||
vlessClient := model.Client{Email: sharedEmail, ID: wantUUID, Enable: true, Flow: "xtls-rprx-vision"}
|
vlessClient := model.Client{Email: sharedEmail, ID: wantUUID, Enable: true, Flow: wantFlow}
|
||||||
if err := svc.SyncInbound(nil, vlessInbound.Id, []model.Client{vlessClient}); err != nil {
|
if err := svc.SyncInbound(nil, vlessInbound.Id, []model.Client{vlessClient}); err != nil {
|
||||||
t.Fatalf("vless SyncInbound: %v", err)
|
t.Fatalf("vless SyncInbound: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -52,7 +53,61 @@ func TestSyncInbound_PreservesCredentialsAcrossProtocols(t *testing.T) {
|
||||||
if row.Auth != wantAuth {
|
if row.Auth != wantAuth {
|
||||||
t.Errorf("Auth not persisted: got %q, want %q", row.Auth, wantAuth)
|
t.Errorf("Auth not persisted: got %q, want %q", row.Auth, wantAuth)
|
||||||
}
|
}
|
||||||
if row.Flow != "xtls-rprx-vision" {
|
|
||||||
t.Errorf("Flow was clobbered by Hysteria sync: got %q, want xtls-rprx-vision", row.Flow)
|
vlessList, err := svc.ListForInbound(nil, vlessInbound.Id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListForInbound(vless): %v", err)
|
||||||
|
}
|
||||||
|
if len(vlessList) != 1 || vlessList[0].Flow != wantFlow {
|
||||||
|
t.Errorf("VLESS inbound should still report flow=%q via FlowOverride, got %#v", wantFlow, vlessList)
|
||||||
|
}
|
||||||
|
|
||||||
|
hysteriaList, err := svc.ListForInbound(nil, hysteriaInbound.Id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListForInbound(hysteria): %v", err)
|
||||||
|
}
|
||||||
|
if len(hysteriaList) != 1 || hysteriaList[0].Flow != "" {
|
||||||
|
t.Errorf("Hysteria inbound should report empty flow, got %#v", hysteriaList)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncInbound_AllowsClearingFlow(t *testing.T) {
|
||||||
|
dbDir := t.TempDir()
|
||||||
|
t.Setenv("XUI_DB_FOLDER", dbDir)
|
||||||
|
if err := database.InitDB(filepath.Join(dbDir, "3x-ui.db")); err != nil {
|
||||||
|
t.Fatalf("InitDB: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = database.CloseDB() })
|
||||||
|
|
||||||
|
db := database.GetDB()
|
||||||
|
|
||||||
|
vless := &model.Inbound{Tag: "vless-in", Enable: true, Port: 10003, Protocol: model.VLESS}
|
||||||
|
if err := db.Create(vless).Error; err != nil {
|
||||||
|
t.Fatalf("create vless inbound: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := ClientService{}
|
||||||
|
const email = "alice@example.com"
|
||||||
|
const uid = "ce8d33df-3a64-4f10-8f9b-91c3a8e0c002"
|
||||||
|
|
||||||
|
withFlow := model.Client{Email: email, ID: uid, Enable: true, Flow: "xtls-rprx-vision"}
|
||||||
|
if err := svc.SyncInbound(nil, vless.Id, []model.Client{withFlow}); err != nil {
|
||||||
|
t.Fatalf("vless SyncInbound (set flow): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cleared := model.Client{Email: email, ID: uid, Enable: true, Flow: ""}
|
||||||
|
if err := svc.SyncInbound(nil, vless.Id, []model.Client{cleared}); err != nil {
|
||||||
|
t.Fatalf("vless SyncInbound (clear flow): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := svc.ListForInbound(nil, vless.Id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListForInbound: %v", err)
|
||||||
|
}
|
||||||
|
if len(list) != 1 {
|
||||||
|
t.Fatalf("expected 1 client, got %d", len(list))
|
||||||
|
}
|
||||||
|
if list[0].Flow != "" {
|
||||||
|
t.Errorf("flow should be clearable on the owning inbound, got %q (Copilot review on #4545)", list[0].Flow)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue