224 lines
5.6 KiB
Go
224 lines
5.6 KiB
Go
package bothandler
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"slices"
|
|
"strings"
|
|
|
|
"git.pengzhan.dev/aimaren/internal/storage"
|
|
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
|
|
)
|
|
|
|
// BotMode selects how a BotHandler filters incoming messages by chat ID.
type BotMode int

// Supported BotMode values.
const (
	// UnsafeMode accepts messages from every chat; the handler's ID list is
	// not consulted.
	UnsafeMode BotMode = 0
	// WhiteListMode accepts messages only from chats whose ID is in the
	// handler's ID list.
	WhiteListMode BotMode = 1
	// BlackListMode accepts messages from every chat except those whose ID is
	// in the handler's ID list.
	BlackListMode BotMode = 2
)
|
|
|
|
// BotHandler 处理 Telegram 请求
|
|
type BotHandler struct {
|
|
bot *tgbotapi.BotAPI
|
|
store storage.Storer // 用你的接口而不是具体类型
|
|
mode BotMode
|
|
list []int64
|
|
}
|
|
|
|
func NewBotHandler(bot *tgbotapi.BotAPI, store storage.Storer, mode BotMode, list []int64) *BotHandler {
|
|
return &BotHandler{
|
|
bot: bot,
|
|
store: store,
|
|
mode: mode,
|
|
list: list,
|
|
}
|
|
}
|
|
|
|
func (h *BotHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
http.Error(w, "cannot read body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var update tgbotapi.Update
|
|
if err := json.Unmarshal(body, &update); err != nil {
|
|
http.Error(w, "bad json", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if update.Message == nil {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
text := strings.TrimSpace(update.Message.Text)
|
|
chatID := update.Message.Chat.ID
|
|
strChatId := fmt.Sprintf("%d", chatID)
|
|
|
|
// 模式检查
|
|
switch h.mode {
|
|
case UnsafeMode:
|
|
// nothing
|
|
case WhiteListMode:
|
|
if !slices.Contains(h.list, chatID) {
|
|
log.Printf("❌ Chat ID %d not in whitelist, ignore: %s", chatID, text)
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
case BlackListMode:
|
|
if slices.Contains(h.list, chatID) {
|
|
log.Printf("❌ Chat ID %d in blacklist, ignore: %s", chatID, text)
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
}
|
|
ctx := context.Background()
|
|
userState, err := h.store.FetchUserState(ctx)
|
|
if err != nil {
|
|
log.Printf("❌ FetchUserState error: %v", err)
|
|
http.Error(w, "fetch user state error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if userState.Users == nil {
|
|
userState.Users = make(map[string]storage.ChatState)
|
|
}
|
|
|
|
cs, ok := userState.Users[strChatId]
|
|
|
|
// 判断是否为命令
|
|
if update.Message.IsCommand() {
|
|
switch update.Message.Command() {
|
|
|
|
case "start":
|
|
h.send(chatID, helpText())
|
|
case "hello":
|
|
h.send(chatID, escapeMarkdownV2(fmt.Sprintf("👋 你好! 你的 chat id 是 %d", chatID)))
|
|
case "ride":
|
|
if ok && cs.Registered {
|
|
h.send(chatID, "✅ 你已经注册过了!")
|
|
break
|
|
}
|
|
if ok && cs.CurrentOp == "ride_waiting" {
|
|
h.send(chatID, "⏳ 你已经在等待输入邀请码,请直接发送邀请码。")
|
|
break
|
|
}
|
|
userState.Users[strChatId] = storage.ChatState{
|
|
ChatID: chatID,
|
|
CurrentOp: "ride_waiting",
|
|
Context: "",
|
|
}
|
|
if err := h.store.UpdateUserState(ctx, userState); err != nil {
|
|
log.Printf("❌ UpdateUserState error: %v", err)
|
|
}
|
|
h.send(chatID, "📩 请输入你的邀请码:")
|
|
|
|
case "tea":
|
|
h.send(chatID, "☕ 本项目不盈利,但为了周转,需要你的支持。")
|
|
|
|
default:
|
|
h.send(chatID, helpText())
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
// 非命令消息:检查是否在等待邀请码
|
|
if ok && cs.CurrentOp == "ride_waiting" {
|
|
code := text
|
|
if slices.Contains(userState.FreeCode, code) {
|
|
userState.Users[strChatId] = storage.ChatState{
|
|
ChatID: chatID,
|
|
Registered: true,
|
|
InviteCode: code,
|
|
}
|
|
userState.FreeCode = slices.Delete(userState.FreeCode, slices.Index(userState.FreeCode, code), slices.Index(userState.FreeCode, code)+1)
|
|
if err := h.store.UpdateUserState(ctx, userState); err != nil {
|
|
log.Printf("❌ UpdateUserState error: %v", err)
|
|
}
|
|
log.Printf("✅ Chat ID %d registered with invite code: %s", chatID, code)
|
|
h.store.AddChatID(ctx, chatID)
|
|
log.Printf("💾 Chat ID %d added to app state.", chatID)
|
|
h.send(chatID, "✅ 邀请码校验成功,已注册!")
|
|
h.send(chatID, h.currentStockText(ctx))
|
|
} else {
|
|
userState.Users[strChatId] = storage.ChatState{
|
|
ChatID: chatID,
|
|
}
|
|
if err := h.store.UpdateUserState(ctx, userState); err != nil {
|
|
log.Printf("❌ UpdateUserState error: %v", err)
|
|
}
|
|
h.send(chatID, "❌ 邀请码错误。")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
// send 封装消息发送
|
|
func (h *BotHandler) send(chatID int64, text string) {
|
|
msg := tgbotapi.NewMessage(chatID, text)
|
|
msg.ParseMode = "MarkdownV2"
|
|
if _, err := h.bot.Send(msg); err != nil {
|
|
log.Printf("❌ send message error: %v", err)
|
|
}
|
|
}
|
|
|
|
// helpText returns the newline-separated list of commands the bot supports.
//
// The text is plain (not MarkdownV2-escaped).
func helpText() string {
	lines := []string{
		"🤖 可用命令:",
		"/start 查看帮助",
		"/hello 查看 chat id",
		"/ride 注册",
		"/tea 支持作者",
	}
	return strings.Join(lines, "\n")
}
|
|
|
|
func (h *BotHandler) currentStockText(ctx context.Context) string {
|
|
as, err := h.store.FetchAppState(ctx)
|
|
if err != nil {
|
|
log.Printf("❌ FetchAppState error: %v", err)
|
|
return escapeMarkdownV2("❌ 无法获取当前库存状态")
|
|
}
|
|
|
|
if len(as.Bags) == 0 {
|
|
return escapeMarkdownV2("当前没有可用的库存。")
|
|
}
|
|
|
|
var sb strings.Builder
|
|
sb.WriteString(escapeMarkdownV2("当前缓存库存:") + "\n")
|
|
|
|
for _, bag := range as.Bags {
|
|
if bag.Availability {
|
|
// MarkdownV2 link: [escapedName](escapedURL)
|
|
name := escapeMarkdownV2(bag.Name)
|
|
url := escapeMarkdownV2(bag.URL)
|
|
sb.WriteString(fmt.Sprintf("[%s](%s)\n", name, url))
|
|
}
|
|
}
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
// markdownV2Escaper prefixes, in a single pass, every character that
// Telegram's MarkdownV2 parse mode reserves in ordinary text with a
// backslash. The backslash itself is also escaped, because it is the escape
// character.
var markdownV2Escaper = strings.NewReplacer(
	`\`, `\\`,
	"_", `\_`,
	"*", `\*`,
	"[", `\[`,
	"]", `\]`,
	"(", `\(`,
	")", `\)`,
	"~", `\~`,
	"`", "\\`",
	">", `\>`,
	"#", `\#`,
	"+", `\+`,
	"-", `\-`,
	"=", `\=`,
	"|", `\|`,
	"{", `\{`,
	"}", `\}`,
	".", `\.`,
	"!", `\!`,
)

// escapeMarkdownV2 returns text with every MarkdownV2-reserved character
// (and the backslash) prefixed by a backslash, so that Telegram renders it
// literally when the message is sent with ParseMode "MarkdownV2".
//
// Escaping the backslash matters. Without it, an input backslash would
// combine with the next character, or with an escape added here, and either
// change the rendered text or make Telegram reject the message.
func escapeMarkdownV2(text string) string {
	return markdownV2Escaper.Replace(text)
}
|