Files
2025-07-26 05:58:59 +00:00

133 lines
3.4 KiB
Go

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package firestore
import (
"fmt"
pb "cloud.google.com/go/firestore/apiv1/firestorepb"
)
// typeKey is the map key whose string value identifies a specially encoded
// Firestore map value.
const typeKey = "__type__"

// typeValVector is the value stored under typeKey when the map encodes a
// vector.
const typeValVector = "__vector__"

// valueKey is the map key holding the array of vector components.
const valueKey = "value"
type (
	// Vector64 is an embedding vector whose components are float64 values.
	Vector64 []float64

	// Vector32 is an embedding vector whose components are float32 values.
	Vector32 []float32
)
// vectorToProtoValue encodes v as a Firestore [pb.Value].
//
// A nil slice is encoded as nullValue. Any other slice (including an empty,
// non-nil one) is encoded as a map value of the form
//
//	{"__type__": "__vector__", "value": [v[0], v[1], ...]}
//
// where every component is widened to float64 before conversion.
func vectorToProtoValue[T float32 | float64](v []T) *pb.Value {
	if v == nil {
		return nullValue
	}

	elems := make([]*pb.Value, 0, len(v))
	for _, component := range v {
		elems = append(elems, floatToProtoValue(float64(component)))
	}

	arrayVal := &pb.Value{
		ValueType: &pb.Value_ArrayValue{
			ArrayValue: &pb.ArrayValue{Values: elems},
		},
	}
	fields := map[string]*pb.Value{
		typeKey:  stringToProtoValue(typeValVector),
		valueKey: arrayVal,
	}
	return &pb.Value{
		ValueType: &pb.Value_MapValue{
			MapValue: &pb.MapValue{Fields: fields},
		},
	}
}
// vector32FromProtoValue decodes a Firestore vector value into a Vector32,
// converting each stored double component to float32.
func vector32FromProtoValue(v *pb.Value) (Vector32, error) {
	floats, err := vectorFromProtoValue[float32](v)
	if err != nil {
		return nil, err
	}
	return Vector32(floats), nil
}
// vector64FromProtoValue decodes a Firestore vector value into a Vector64,
// keeping each stored double component as a float64.
func vector64FromProtoValue(v *pb.Value) (Vector64, error) {
	floats, err := vectorFromProtoValue[float64](v)
	if err != nil {
		return nil, err
	}
	return Vector64(floats), nil
}
// vectorFromProtoValue decodes a Firestore vector map value into a []T.
//
// The array elements are extracted by pbValToVectorVals; each one must hold a
// double value, which is converted to T (losing precision when T is float32).
// A nil v yields an empty, non-nil slice and no error.
//
// It returns an error if v is not a vector map value or if any element is not
// a double value. A nil element is reported as an error rather than causing a
// nil-pointer dereference.
func vectorFromProtoValue[T float32 | float64](v *pb.Value) ([]T, error) {
	pbArrVals, err := pbValToVectorVals(v)
	if err != nil {
		return nil, err
	}
	floats := make([]T, len(pbArrVals))
	for i, fval := range pbArrVals {
		// GetValueType is nil-safe: a nil element yields a nil interface, so
		// the assertion fails cleanly instead of panicking.
		dv, ok := fval.GetValueType().(*pb.Value_DoubleValue)
		if !ok {
			return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.GetValueType())
		}
		floats[i] = T(dv.DoubleValue)
	}
	return floats, nil
}
// pbValToVectorVals validates that v is a Firestore vector value and returns
// its component values. A vector is stored as a map value of the form
//
//	{
//		"__type__": "__vector__",
//		"value":    [<double>, <double>, ...],
//	}
//
// A nil v returns (nil, nil). Otherwise an error is returned if v is not a map
// value, if the "__type__" entry is missing or is not the string "__vector__",
// or if the "value" entry is missing or is not an array value. The returned
// slice is the array stored inside v, not a copy.
func pbValToVectorVals(v *pb.Value) ([]*pb.Value, error) {
	if v == nil {
		return nil, nil
	}
	pbMap, ok := v.ValueType.(*pb.Value_MapValue)
	if !ok {
		return nil, fmt.Errorf("firestore: cannot convert %v to *pb.Value_MapValue", v.ValueType)
	}
	// Nil-safe getters: a hand-built Value may carry nil sub-messages, and
	// direct field access on them would panic.
	m := pbMap.MapValue.GetFields()

	// Check presence explicitly so a missing marker gets a descriptive error
	// (mirroring the valueKey check below) instead of "convert <nil> to string".
	typePb, ok := m[typeKey]
	if !ok {
		return nil, fmt.Errorf("firestore: %v not present in %v", typeKey, m)
	}
	typeVal, err := stringFromProtoValue(typePb)
	if err != nil {
		return nil, err
	}
	if typeVal != typeValVector {
		return nil, fmt.Errorf("firestore: value of %v : %v is not %v", typeKey, typeVal, typeValVector)
	}

	pbVal, ok := m[valueKey]
	if !ok {
		return nil, fmt.Errorf("firestore: %v not present in %v", valueKey, m)
	}
	pbArr, ok := pbVal.GetValueType().(*pb.Value_ArrayValue)
	if !ok {
		return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_ArrayValue", pbVal.GetValueType())
	}
	return pbArr.ArrayValue.GetValues(), nil
}
// stringFromProtoValue returns the string payload carried by v.
//
// It returns an error when v is nil or when v holds any value type other
// than a string value.
func stringFromProtoValue(v *pb.Value) (string, error) {
	if v == nil {
		return "", fmt.Errorf("firestore: failed to convert %v to string", v)
	}
	switch str := v.ValueType.(type) {
	case *pb.Value_StringValue:
		return str.StringValue, nil
	default:
		return "", fmt.Errorf("firestore: failed to convert %v to *pb.Value_StringValue", v.ValueType)
	}
}