68 lines
1.8 KiB
Go
68 lines
1.8 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
|
|
"cloud.google.com/go/firestore"
|
|
"git.pengzhan.dev/aimaren/internal/crawler"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// Location of the persisted application state in Firestore.
const (
	// stateCollection is the Firestore collection that holds the state document.
	stateCollection = "hermes_state"
	// stateDocument is the ID of the single document inside stateCollection
	// that stores the whole AppState.
	stateDocument = "main"
)
|
|
|
|
type FirestoreClient struct {
|
|
client *firestore.Client
|
|
}
|
|
|
|
func NewFirestoreClient(ctx context.Context, projectID string) (*FirestoreClient, error) {
|
|
client, err := firestore.NewClient(ctx, projectID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &FirestoreClient{client: client}, nil
|
|
}
|
|
|
|
// FetchAppState retrieves the entire application state from a single document.
|
|
func (fs *FirestoreClient) FetchAppState(ctx context.Context) (*AppState, error) {
|
|
doc, err := fs.client.Collection(stateCollection).Doc(stateDocument).Get(ctx)
|
|
if err != nil {
|
|
// If the doc doesn't exist, return a new, empty AppState.
|
|
if status.Code(err) == codes.NotFound {
|
|
return &AppState{
|
|
Bags: make(map[string]crawler.Bag),
|
|
ChatIDs: []int64{},
|
|
}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
var state AppState
|
|
if err := doc.DataTo(&state); err != nil {
|
|
return nil, err
|
|
}
|
|
return &state, nil
|
|
}
|
|
|
|
// UpdateAppState writes the entire application state back to the document.
|
|
func (fs *FirestoreClient) UpdateAppState(ctx context.Context, newState *AppState) error {
|
|
_, err := fs.client.Collection(stateCollection).Doc(stateDocument).Set(ctx, newState)
|
|
return err
|
|
}
|
|
|
|
// AddChatID atomically adds a new chat ID to the list in the main state document.
|
|
func (fs *FirestoreClient) AddChatID(ctx context.Context, chatID int64) error {
|
|
docRef := fs.client.Collection(stateCollection).Doc(stateDocument)
|
|
_, err := docRef.Update(ctx, []firestore.Update{
|
|
{Path: "chat_ids", Value: firestore.ArrayUnion(chatID)},
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (fs *FirestoreClient) Close() {
|
|
fs.client.Close()
|
|
}
|