package service import ( "encoding/json" "errors" "strings" "time" "github.com/google/uuid" "github.com/mhsanaei/3x-ui/v3/database" "github.com/mhsanaei/3x-ui/v3/database/model" "github.com/mhsanaei/3x-ui/v3/util/common" "github.com/mhsanaei/3x-ui/v3/xray" "gorm.io/gorm" ) type ClientWithAttachments struct { model.ClientRecord InboundIds []int `json:"inboundIds"` Traffic *xray.ClientTraffic `json:"traffic,omitempty"` } func clientKeyForProtocol(p model.Protocol, rec *model.ClientRecord) string { if rec == nil { return "" } switch p { case model.Trojan: return rec.Password case model.Shadowsocks: return rec.Email case model.Hysteria, model.Hysteria2: return rec.Auth default: return rec.UUID } } type ClientService struct{} func (s *ClientService) SyncInbound(tx *gorm.DB, inboundId int, clients []model.Client) error { if tx == nil { tx = database.GetDB() } if err := tx.Where("inbound_id = ?", inboundId).Delete(&model.ClientInbound{}).Error; err != nil { return err } for i := range clients { c := clients[i] email := strings.TrimSpace(c.Email) if email == "" { continue } incoming := c.ToRecord() row := &model.ClientRecord{} err := tx.Where("email = ?", email).First(row).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } if errors.Is(err, gorm.ErrRecordNotFound) { if err := tx.Create(incoming).Error; err != nil { return err } row = incoming } else { row.UUID = incoming.UUID row.Password = incoming.Password row.Auth = incoming.Auth row.Flow = incoming.Flow row.Security = incoming.Security row.Reverse = incoming.Reverse row.SubID = incoming.SubID row.LimitIP = incoming.LimitIP row.TotalGB = incoming.TotalGB row.ExpiryTime = incoming.ExpiryTime row.Enable = incoming.Enable row.TgID = incoming.TgID row.Comment = incoming.Comment row.Reset = incoming.Reset if incoming.CreatedAt > 0 && (row.CreatedAt == 0 || incoming.CreatedAt < row.CreatedAt) { row.CreatedAt = incoming.CreatedAt } if incoming.UpdatedAt > row.UpdatedAt { row.UpdatedAt = incoming.UpdatedAt } if err := tx.Save(row).Error; err != nil { return err } } link := model.ClientInbound{ ClientId: row.Id, InboundId: inboundId, FlowOverride: c.Flow, } if err := tx.Create(&link).Error; err != nil { return err } } return nil } func (s *ClientService) DetachInbound(tx *gorm.DB, inboundId int) error { if tx == nil { tx = database.GetDB() } return tx.Where("inbound_id = ?", inboundId).Delete(&model.ClientInbound{}).Error } func (s *ClientService) ListForInbound(tx *gorm.DB, inboundId int) ([]model.Client, error) { if tx == nil { tx = database.GetDB() } type joinedRow struct { model.ClientRecord FlowOverride string } var rows []joinedRow err := tx.Table("clients"). Select("clients.*, client_inbounds.flow_override AS flow_override"). Joins("JOIN client_inbounds ON client_inbounds.client_id = clients.id"). Where("client_inbounds.inbound_id = ?", inboundId). Order("clients.id ASC"). Find(&rows).Error if err != nil { return nil, err } out := make([]model.Client, 0, len(rows)) for i := range rows { c := rows[i].ToClient() if rows[i].FlowOverride != "" { c.Flow = rows[i].FlowOverride } out = append(out, *c) } return out, nil } func (s *ClientService) GetRecordByEmail(tx *gorm.DB, email string) (*model.ClientRecord, error) { if tx == nil { tx = database.GetDB() } row := &model.ClientRecord{} err := tx.Where("email = ?", email).First(row).Error if err != nil { return nil, err } return row, nil } func (s *ClientService) GetInboundIdsForEmail(tx *gorm.DB, email string) ([]int, error) { if tx == nil { tx = database.GetDB() } var ids []int err := tx.Table("client_inbounds"). Select("client_inbounds.inbound_id"). Joins("JOIN clients ON clients.id = client_inbounds.client_id"). Where("clients.email = ?", email). Scan(&ids).Error if err != nil { return nil, err } return ids, nil } func (s *ClientService) GetByID(id int) (*model.ClientRecord, error) { row := &model.ClientRecord{} if err := database.GetDB().Where("id = ?", id).First(row).Error; err != nil { return nil, err } return row, nil } func (s *ClientService) GetInboundIdsForRecord(id int) ([]int, error) { var ids []int err := database.GetDB().Table("client_inbounds"). Where("client_id = ?", id). Order("inbound_id ASC"). Pluck("inbound_id", &ids).Error if err != nil { return nil, err } return ids, nil } func (s *ClientService) List() ([]ClientWithAttachments, error) { db := database.GetDB() var rows []model.ClientRecord if err := db.Order("id ASC").Find(&rows).Error; err != nil { return nil, err } if len(rows) == 0 { return []ClientWithAttachments{}, nil } clientIds := make([]int, 0, len(rows)) emails := make([]string, 0, len(rows)) for i := range rows { clientIds = append(clientIds, rows[i].Id) if rows[i].Email != "" { emails = append(emails, rows[i].Email) } } var links []model.ClientInbound if err := db.Where("client_id IN ?", clientIds).Find(&links).Error; err != nil { return nil, err } attachments := make(map[int][]int, len(rows)) for _, l := range links { attachments[l.ClientId] = append(attachments[l.ClientId], l.InboundId) } trafficByEmail := make(map[string]*xray.ClientTraffic, len(emails)) if len(emails) > 0 { var stats []xray.ClientTraffic if err := db.Where("email IN ?", emails).Find(&stats).Error; err != nil { return nil, err } for i := range stats { trafficByEmail[stats[i].Email] = &stats[i] } } out := make([]ClientWithAttachments, 0, len(rows)) for i := range rows { out = append(out, ClientWithAttachments{ ClientRecord: rows[i], InboundIds: attachments[rows[i].Id], Traffic: trafficByEmail[rows[i].Email], }) } return out, nil } type ClientCreatePayload struct { Client model.Client `json:"client"` InboundIds []int `json:"inboundIds"` } func (s *ClientService) Create(inboundSvc *InboundService, payload *ClientCreatePayload) (bool, error) { if payload == nil { return false, common.NewError("empty payload") } client := payload.Client if strings.TrimSpace(client.Email) == "" { return false, common.NewError("client email is required") } if len(payload.InboundIds) == 0 { return false, common.NewError("at least one inbound is required") } if client.SubID == "" { client.SubID = uuid.NewString() } if !client.Enable { client.Enable = true } now := time.Now().UnixMilli() if client.CreatedAt == 0 { client.CreatedAt = now } client.UpdatedAt = now existing := &model.ClientRecord{} err := database.GetDB().Where("email = ?", client.Email).First(existing).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return false, err } emailTaken := !errors.Is(err, gorm.ErrRecordNotFound) if emailTaken { if existing.SubID == "" || existing.SubID != client.SubID { return false, common.NewError("email already in use:", client.Email) } } needRestart := false for _, ibId := range payload.InboundIds { inbound, getErr := inboundSvc.GetInbound(ibId) if getErr != nil { return needRestart, getErr } if err := s.fillProtocolDefaults(&client, inbound.Protocol); err != nil { return needRestart, err } settingsPayload, mErr := json.Marshal(map[string][]model.Client{"clients": {client}}) if mErr != nil { return needRestart, mErr } nr, addErr := inboundSvc.AddInboundClient(&model.Inbound{ Id: ibId, Settings: string(settingsPayload), }) if addErr != nil { return needRestart, addErr } if nr { needRestart = true } } return needRestart, nil } func (s *ClientService) fillProtocolDefaults(c *model.Client, p model.Protocol) error { switch p { case model.VMESS, model.VLESS: if c.ID == "" { c.ID = uuid.NewString() } case model.Trojan, model.Shadowsocks: if c.Password == "" { c.Password = strings.ReplaceAll(uuid.NewString(), "-", "") } case model.Hysteria, model.Hysteria2: if c.Auth == "" { c.Auth = strings.ReplaceAll(uuid.NewString(), "-", "") } } return nil } func (s *ClientService) Update(inboundSvc *InboundService, id int, updated model.Client) (bool, error) { existing, err := s.GetByID(id) if err != nil { return false, err } inboundIds, err := s.GetInboundIdsForRecord(id) if err != nil { return false, err } if strings.TrimSpace(updated.Email) == "" { return false, common.NewError("client email is required") } if updated.SubID == "" { updated.SubID = existing.SubID } if updated.SubID == "" { updated.SubID = uuid.NewString() } updated.UpdatedAt = time.Now().UnixMilli() if updated.CreatedAt == 0 { updated.CreatedAt = existing.CreatedAt } needRestart := false for _, ibId := range inboundIds { inbound, getErr := inboundSvc.GetInbound(ibId) if getErr != nil { return needRestart, getErr } oldKey := clientKeyForProtocol(inbound.Protocol, existing) if oldKey == "" { continue } if err := s.fillProtocolDefaults(&updated, inbound.Protocol); err != nil { return needRestart, err } settingsPayload, mErr := json.Marshal(map[string][]model.Client{"clients": {updated}}) if mErr != nil { return needRestart, mErr } nr, upErr := inboundSvc.UpdateInboundClient(&model.Inbound{ Id: ibId, Settings: string(settingsPayload), }, oldKey) if upErr != nil { return needRestart, upErr } if nr { needRestart = true } } return needRestart, nil } func (s *ClientService) Delete(inboundSvc *InboundService, id int, keepTraffic bool) (bool, error) { existing, err := s.GetByID(id) if err != nil { return false, err } inboundIds, err := s.GetInboundIdsForRecord(id) if err != nil { return false, err } needRestart := false for _, ibId := range inboundIds { inbound, getErr := inboundSvc.GetInbound(ibId) if getErr != nil { return needRestart, getErr } key := clientKeyForProtocol(inbound.Protocol, existing) if key == "" { continue } nr, delErr := inboundSvc.DelInboundClient(ibId, key) if delErr != nil { return needRestart, delErr } if nr { needRestart = true } } db := database.GetDB() if err := db.Where("client_id = ?", id).Delete(&model.ClientInbound{}).Error; err != nil { return needRestart, err } if !keepTraffic && existing.Email != "" { if err := db.Where("email = ?", existing.Email).Delete(&xray.ClientTraffic{}).Error; err != nil { return needRestart, err } if err := db.Where("client_email = ?", existing.Email).Delete(&model.InboundClientIps{}).Error; err != nil { return needRestart, err } } if err := db.Delete(&model.ClientRecord{}, id).Error; err != nil { return needRestart, err } return needRestart, nil } func (s *ClientService) Attach(inboundSvc *InboundService, id int, inboundIds []int) (bool, error) { existing, err := s.GetByID(id) if err != nil { return false, err } currentIds, err := s.GetInboundIdsForRecord(id) if err != nil { return false, err } have := make(map[int]struct{}, len(currentIds)) for _, x := range currentIds { have[x] = struct{}{} } clientWire := existing.ToClient() clientWire.UpdatedAt = time.Now().UnixMilli() needRestart := false for _, ibId := range inboundIds { if _, attached := have[ibId]; attached { continue } inbound, getErr := inboundSvc.GetInbound(ibId) if getErr != nil { return needRestart, getErr } copyClient := *clientWire if err := s.fillProtocolDefaults(©Client, inbound.Protocol); err != nil { return needRestart, err } settingsPayload, mErr := json.Marshal(map[string][]model.Client{"clients": {copyClient}}) if mErr != nil { return needRestart, mErr } nr, addErr := inboundSvc.AddInboundClient(&model.Inbound{ Id: ibId, Settings: string(settingsPayload), }) if addErr != nil { return needRestart, addErr } if nr { needRestart = true } } return needRestart, nil } func (s *ClientService) Detach(inboundSvc *InboundService, id int, inboundIds []int) (bool, error) { existing, err := s.GetByID(id) if err != nil { return false, err } currentIds, err := s.GetInboundIdsForRecord(id) if err != nil { return false, err } have := make(map[int]struct{}, len(currentIds)) for _, x := range currentIds { have[x] = struct{}{} } needRestart := false for _, ibId := range inboundIds { if _, attached := have[ibId]; !attached { continue } inbound, getErr := inboundSvc.GetInbound(ibId) if getErr != nil { return needRestart, getErr } key := clientKeyForProtocol(inbound.Protocol, existing) if key == "" { continue } nr, delErr := inboundSvc.DelInboundClient(ibId, key) if delErr != nil { return needRestart, delErr } if nr { needRestart = true } } return needRestart, nil }