Files
2025-07-26 17:18:41 -07:00

224 lines
5.6 KiB
Go

package bothandler
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"slices"
"strings"
"git.pengzhan.dev/aimaren/internal/storage"
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
)
// BotMode selects how a BotHandler filters incoming chats by chat ID.
type BotMode int

// Supported filtering modes. The numeric values are fixed in declaration
// order starting from zero.
const (
	// UnsafeMode performs no chat filtering at all.
	UnsafeMode BotMode = 0
	// WhiteListMode only serves chats whose ID appears in the handler's list.
	WhiteListMode BotMode = 1
	// BlackListMode ignores chats whose ID appears in the handler's list.
	BlackListMode BotMode = 2
)
// BotHandler handles Telegram webhook requests (see its ServeHTTP method).
type BotHandler struct {
	bot   *tgbotapi.BotAPI // client used to send replies back to chats
	store storage.Storer   // depend on the Storer interface rather than a concrete storage type
	mode  BotMode          // chat filtering mode applied before any update is processed
	list  []int64          // chat IDs used by WhiteListMode / BlackListMode; unused in UnsafeMode
}
// NewBotHandler builds a BotHandler that replies through bot, keeps user and
// app state in store, and filters chats according to mode using list (the
// whitelist or blacklist of chat IDs; list is ignored in UnsafeMode).
//
// list is copied, so later changes to the caller's slice neither alter the
// filter nor race with concurrent ServeHTTP calls that read it.
func NewBotHandler(bot *tgbotapi.BotAPI, store storage.Storer, mode BotMode, list []int64) *BotHandler {
	return &BotHandler{
		bot:   bot,
		store: store,
		mode:  mode,
		list:  slices.Clone(list),
	}
}
// ServeHTTP handles one Telegram webhook delivery.
//
// Only POST is accepted (405 otherwise). The body is decoded as a
// tgbotapi.Update; read or decode failures get 400. Updates without a message
// (or without a chat), and messages from chats rejected by the
// whitelist/blacklist filter, are acknowledged with 200 and otherwise ignored.
// Commands (/start, /hello, /ride, /tea) are answered directly; other text is
// treated as an invite code when the chat previously issued /ride. Failure to
// load the user state yields 500; every other handled path replies 200.
func (h *BotHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Telegram updates are small JSON documents. Capping the body stops an
	// arbitrary caller of this public endpoint from making the handler buffer
	// an unbounded payload in memory.
	const maxUpdateBodyBytes = 1 << 20

	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxUpdateBodyBytes))
	if err != nil {
		http.Error(w, "cannot read body", http.StatusBadRequest)
		return
	}
	var update tgbotapi.Update
	if err := json.Unmarshal(body, &update); err != nil {
		http.Error(w, "bad json", http.StatusBadRequest)
		return
	}
	msg := update.Message
	// A message without a chat cannot be answered; checking it here also
	// avoids a nil dereference on hand-crafted payloads.
	if msg == nil || msg.Chat == nil {
		w.WriteHeader(http.StatusOK)
		return
	}

	text := strings.TrimSpace(msg.Text)
	chatID := msg.Chat.ID
	chatKey := fmt.Sprintf("%d", chatID) // user-state map is keyed by the decimal chat ID

	if !h.isChatAllowed(chatID, text) {
		w.WriteHeader(http.StatusOK)
		return
	}

	// Tie storage calls to the request so they stop if the webhook call is
	// abandoned.
	ctx := r.Context()
	// NOTE(review): the state is read, modified and written back without any
	// locking here; two concurrent deliveries could e.g. both redeem the same
	// invite code unless the store serializes this.
	userState, err := h.store.FetchUserState(ctx)
	if err != nil {
		log.Printf("❌ FetchUserState error: %v", err)
		http.Error(w, "fetch user state error", http.StatusInternalServerError)
		return
	}
	if userState.Users == nil {
		userState.Users = make(map[string]storage.ChatState)
	}
	cs, ok := userState.Users[chatKey]

	if msg.IsCommand() {
		switch msg.Command() {
		case "start":
			h.send(chatID, helpText())
		case "hello":
			h.send(chatID, escapeMarkdownV2(fmt.Sprintf("👋 你好! 你的 chat id 是 %d", chatID)))
		case "ride":
			switch {
			case ok && cs.Registered:
				// '!' is reserved in MarkdownV2; unescaped, Telegram rejects the message.
				h.send(chatID, escapeMarkdownV2("✅ 你已经注册过了!"))
			case ok && cs.CurrentOp == "ride_waiting":
				h.send(chatID, "⏳ 你已经在等待输入邀请码,请直接发送邀请码。")
			default:
				userState.Users[chatKey] = storage.ChatState{
					ChatID:    chatID,
					CurrentOp: "ride_waiting",
					Context:   "",
				}
				if err := h.store.UpdateUserState(ctx, userState); err != nil {
					log.Printf("❌ UpdateUserState error: %v", err)
				}
				h.send(chatID, "📩 请输入你的邀请码:")
			}
		case "tea":
			h.send(chatID, "☕ 本项目不盈利,但为了周转,需要你的支持。")
		default:
			h.send(chatID, helpText())
		}
		w.WriteHeader(http.StatusOK)
		return
	}

	// Plain text only matters while the chat is waiting for an invite code.
	if !ok || cs.CurrentOp != "ride_waiting" {
		w.WriteHeader(http.StatusOK)
		return
	}

	code := text
	idx := slices.Index(userState.FreeCode, code)
	if idx < 0 {
		// Wrong code: drop the pending operation so the user must /ride again.
		userState.Users[chatKey] = storage.ChatState{
			ChatID: chatID,
		}
		if err := h.store.UpdateUserState(ctx, userState); err != nil {
			log.Printf("❌ UpdateUserState error: %v", err)
		}
		h.send(chatID, "❌ 邀请码错误。")
		w.WriteHeader(http.StatusOK)
		return
	}

	userState.Users[chatKey] = storage.ChatState{
		ChatID:     chatID,
		Registered: true,
		InviteCode: code,
	}
	// Consume the code so it cannot be redeemed a second time.
	userState.FreeCode = slices.Delete(userState.FreeCode, idx, idx+1)
	if err := h.store.UpdateUserState(ctx, userState); err != nil {
		// Neither the registration nor the consumed code was persisted, so do
		// not report success or add the chat to the app state.
		log.Printf("❌ UpdateUserState error: %v", err)
		h.send(chatID, escapeMarkdownV2("❌ 注册失败,请稍后重新发送邀请码。"))
		w.WriteHeader(http.StatusOK)
		return
	}
	log.Printf("✅ Chat ID %d registered with invite code: %s", chatID, code)
	// NOTE(review): any value AddChatID returns (e.g. an error) is discarded
	// here — confirm against storage.Storer.
	h.store.AddChatID(ctx, chatID)
	log.Printf("💾 Chat ID %d added to app state.", chatID)
	h.send(chatID, escapeMarkdownV2("✅ 邀请码校验成功,已注册!"))
	h.send(chatID, h.currentStockText(ctx))
	w.WriteHeader(http.StatusOK)
}

// isChatAllowed reports whether chatID passes the handler's whitelist or
// blacklist filter, logging the ignored message text when it does not.
// UnsafeMode (and any unrecognized mode) lets every chat through.
func (h *BotHandler) isChatAllowed(chatID int64, text string) bool {
	switch h.mode {
	case WhiteListMode:
		if !slices.Contains(h.list, chatID) {
			log.Printf("❌ Chat ID %d not in whitelist, ignore: %s", chatID, text)
			return false
		}
	case BlackListMode:
		if slices.Contains(h.list, chatID) {
			log.Printf("❌ Chat ID %d in blacklist, ignore: %s", chatID, text)
			return false
		}
	}
	return true
}
// send delivers text to chatID with ParseMode "MarkdownV2". The text is
// passed through unmodified, so it must already be valid MarkdownV2. A failed
// send is only logged.
func (h *BotHandler) send(chatID int64, text string) {
	out := tgbotapi.NewMessage(chatID, text)
	out.ParseMode = "MarkdownV2"
	_, sendErr := h.bot.Send(out)
	if sendErr == nil {
		return
	}
	log.Printf("❌ send message error: %v", sendErr)
}
// helpText returns the command overview shown for /start and for unknown
// commands. It contains no MarkdownV2-reserved characters, so it can be sent
// without escaping.
func helpText() string {
	const header = "🤖 可用命令:\n"
	return header +
		"/start 查看帮助\n" +
		"/hello 查看 chat id\n" +
		"/ride 注册\n" +
		"/tea 支持作者"
}
// currentStockText renders the cached bag inventory from the app state as a
// MarkdownV2 message: a header line followed by one inline link per available
// bag. If the app state cannot be fetched, or no bag is currently available
// (including when there are no bags at all), a single escaped status line is
// returned instead.
func (h *BotHandler) currentStockText(ctx context.Context) string {
	as, err := h.store.FetchAppState(ctx)
	if err != nil {
		log.Printf("❌ FetchAppState error: %v", err)
		return escapeMarkdownV2("❌ 无法获取当前库存状态")
	}

	// Inside the (...) part of a MarkdownV2 inline link Telegram only requires
	// ')' and '\' to be escaped; the general text escaping is meant for the
	// [...] part. The replacer works in a single pass, so the backslashes it
	// inserts are never escaped again.
	linkURLEscaper := strings.NewReplacer(`\`, `\\`, `)`, `\)`)

	var links strings.Builder
	available := 0
	for _, bag := range as.Bags {
		if !bag.Availability {
			continue
		}
		available++
		fmt.Fprintf(&links, "[%s](%s)\n", escapeMarkdownV2(bag.Name), linkURLEscaper.Replace(bag.URL))
	}
	// Previously a list where every bag was unavailable produced a bare header.
	if available == 0 {
		return escapeMarkdownV2("当前没有可用的库存。")
	}
	return escapeMarkdownV2("当前缓存库存:") + "\n" + links.String()
}
// escapeMarkdownV2 escapes text for use as ordinary text in a Telegram
// MarkdownV2 message by prefixing every reserved character with a backslash.
// The backslash itself is escaped too. Otherwise a literal '\' in the input
// would escape the character after it, or consume the escape added for it,
// and the text would render wrong or be rejected by Telegram.
func escapeMarkdownV2(text string) string {
	// Characters Telegram reserves in MarkdownV2, including the escape char.
	const reserved = "\\_*[]()~`>#+-=|{}.!"
	var b strings.Builder
	b.Grow(len(text))
	// Every reserved character is ASCII, and ASCII bytes never occur inside a
	// multi-byte UTF-8 sequence. A byte-wise scan is therefore safe and leaves
	// all other bytes untouched. A single pass also means inserted
	// backslashes are never re-escaped.
	for i := 0; i < len(text); i++ {
		c := text[i]
		if strings.IndexByte(reserved, c) >= 0 {
			b.WriteByte('\\')
		}
		b.WriteByte(c)
	}
	return b.String()
}