diff --git a/web/global/hashStorage.go b/web/global/hashStorage.go index 099a54ed..3d5a6a0d 100644 --- a/web/global/hashStorage.go +++ b/web/global/hashStorage.go @@ -3,8 +3,10 @@ package global import ( "crypto/md5" "encoding/hex" + "regexp" "sync" "time" + "x-ui/util/common" ) type HashEntry struct { @@ -59,15 +61,23 @@ func (h *HashStorage) saveValue(query string) string { return md5HashString } -func (h *HashStorage) GetValue(hash string) string { +func (h *HashStorage) GetValue(hash string) (string, error) { h.RLock() defer h.RUnlock() entry, exists := h.Data[hash] if !exists { - return hash + if h.isMD5(hash) { + return "", common.NewError("hash not found in storage!") + } + return hash, nil } - return entry.Value + return entry.Value, nil +} + +func (h *HashStorage) isMD5(hash string) bool { + match, _ := regexp.MatchString("^[a-f0-9]{32}$", hash) + return match } func (h *HashStorage) RemoveExpiredHashes() { diff --git a/web/service/tgbot.go b/web/service/tgbot.go index 8777cbbd..fa9ccb1d 100644 --- a/web/service/tgbot.go +++ b/web/service/tgbot.go @@ -61,8 +61,9 @@ func (t *Tgbot) Start(i18nFS embed.FS) error { return err } - // init hash storage - t.hashStorage = global.NewHashStorage(5*time.Minute, false) + // init hash storage => store callback queries + // NOTE: it only save the query if its length is more than 64 chars. + t.hashStorage = global.NewHashStorage(20*time.Minute, false) tgBottoken, err := t.settingService.GetTgBotToken() if err != nil || tgBottoken == "" { @@ -199,8 +200,12 @@ func (t *Tgbot) asnwerCallback(callbackQuery *telego.CallbackQuery, isAdmin bool chatId := callbackQuery.Message.Chat.ID if isAdmin { - // get query from hash storage (if the query was <= 64 chars hash storage dont save the hash and return data itself) - decodedQuery := t.hashStorage.GetValue(callbackQuery.Data) + // get query from hash storage + decodedQuery, err := t.hashStorage.GetValue(callbackQuery.Data) + if err != nil { + t.SendMsgToTgbot(chatId, "Query not found! Please use the command again!") + return + } dataArray := strings.Split(decodedQuery, " ") if len(dataArray) >= 2 && len(dataArray[1]) > 0 {