gomog/internal/engine/aggregate_helpers.go

751 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package engine
// aggregate_helpers.go - 聚合辅助函数
// 使用 helpers.go 中的公共辅助函数
import (
"fmt"
"math"
"math/rand"
"strings"
"time"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// init seeds the shared math/rand generator with the current wall-clock time
// (nanosecond resolution) when the package is loaded.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
// concat joins the elements of an array operand into a single string.
//
// String elements are appended verbatim; every other element is rendered with
// FormatValueToString. A non-array operand yields "". The data argument is
// not consulted by this helper.
func (e *AggregationEngine) concat(operand interface{}, data map[string]interface{}) string {
	arr, ok := operand.([]interface{})
	if !ok {
		return ""
	}
	// A Builder avoids re-copying the accumulated string on every append.
	var sb strings.Builder
	for _, item := range arr {
		if str, ok := item.(string); ok {
			sb.WriteString(str)
			continue
		}
		sb.WriteString(FormatValueToString(item))
	}
	return sb.String()
}
// substr returns a byte-indexed substring.
//
// The operand must be an array [source, start] or [source, start, length];
// anything else yields "". The source string is obtained through
// GetFieldValueStr. A negative start is clamped to 0, a start at or past the
// end yields "", and a missing or non-positive length means "to the end of
// the string".
func (e *AggregationEngine) substr(operand interface{}, data map[string]interface{}) string {
	args, isList := operand.([]interface{})
	if !isList || len(args) < 2 {
		return ""
	}

	source := GetFieldValueStr(types.Document{Data: data}, args[0])
	total := len(source)

	from := int(ToFloat64(args[1]))
	if from < 0 {
		from = 0
	}
	if from >= total {
		return ""
	}

	to := total
	if len(args) > 2 {
		if count := int(ToFloat64(args[2])); count > 0 {
			to = from + count
			if to > total {
				to = total
			}
		}
	}
	return source[from:to]
}
// add evaluates every element of an array operand and returns the sum of the
// results converted with ToFloat64. A non-array operand yields 0.
func (e *AggregationEngine) add(operand interface{}, data map[string]interface{}) float64 {
	terms, isList := operand.([]interface{})
	if !isList {
		return 0
	}
	var total float64
	for i := range terms {
		total += ToFloat64(e.evaluateExpression(data, terms[i]))
	}
	return total
}
// multiply evaluates every element of an array operand and returns the
// product of the results converted with ToFloat64. A non-array operand yields
// 0; an empty array yields 1 (the multiplicative identity).
func (e *AggregationEngine) multiply(operand interface{}, data map[string]interface{}) float64 {
	factors, isList := operand.([]interface{})
	if !isList {
		return 0
	}
	result := 1.0
	for i := range factors {
		result *= ToFloat64(e.evaluateExpression(data, factors[i]))
	}
	return result
}
// divide evaluates the first two elements of an array operand and returns
// their quotient. It returns 0 when the operand is not an array, has fewer
// than two elements, or the divisor evaluates to 0.
func (e *AggregationEngine) divide(operand interface{}, data map[string]interface{}) float64 {
	args, isList := operand.([]interface{})
	if !isList || len(args) < 2 {
		return 0
	}
	numerator := ToFloat64(e.evaluateExpression(data, args[0]))
	denominator := ToFloat64(e.evaluateExpression(data, args[1]))
	if denominator != 0 {
		return numerator / denominator
	}
	// Division by zero is reported as 0 rather than ±Inf/NaN.
	return 0
}
// ifNull evaluates the first element of an array operand and returns it
// unless it is nil, in which case the second element is evaluated and
// returned. A non-array operand or one with fewer than two elements yields
// nil.
func (e *AggregationEngine) ifNull(operand interface{}, data map[string]interface{}) interface{} {
	args, isList := operand.([]interface{})
	if !isList || len(args) < 2 {
		return nil
	}
	if primary := e.evaluateExpression(data, args[0]); primary != nil {
		return primary
	}
	return e.evaluateExpression(data, args[1])
}
// cond implements a conditional expression.
//
// Two operand shapes are accepted:
//   - a document with "if", "then" and "else" keys (all three required);
//   - an array [if, then, else] with at least three elements.
//
// The condition is evaluated against data and tested with IsTrueValue; the
// selected branch is then evaluated as an expression and its result returned,
// matching how switchExpr treats its "then" branches. Any other operand shape
// yields nil.
func (e *AggregationEngine) cond(operand interface{}, data map[string]interface{}) interface{} {
	var ifExpr, thenExpr, elseExpr interface{}

	switch op := operand.(type) {
	case map[string]interface{}:
		var hasIf, hasThen, hasElse bool
		ifExpr, hasIf = op["if"]
		thenExpr, hasThen = op["then"]
		elseExpr, hasElse = op["else"]
		if !hasIf || !hasThen || !hasElse {
			return nil
		}
	case []interface{}:
		if len(op) < 3 {
			return nil
		}
		ifExpr, thenExpr, elseExpr = op[0], op[1], op[2]
	default:
		return nil
	}

	// Only the chosen branch is evaluated, so the other branch is never run.
	if IsTrueValue(e.evaluateExpression(data, ifExpr)) {
		return e.evaluateExpression(data, thenExpr)
	}
	return e.evaluateExpression(data, elseExpr)
}
// switchExpr implements a $switch-style multi-branch conditional.
//
// The operand must be a document with a "branches" array of
// {"case": ..., "then": ...} documents and an optional "default". Branches
// are tried in order; the first whose case is truthy (per IsTrueValue) has
// its "then" expression evaluated and returned. Entries in "branches" that
// are not documents are skipped. When no branch matches, the "default"
// expression is evaluated and returned, or nil when there is no default. A
// non-document operand yields nil.
func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]interface{}) interface{} {
	spec, ok := operand.(map[string]interface{})
	if !ok {
		return nil
	}

	branches, _ := spec["branches"].([]interface{})
	for _, raw := range branches {
		branch, ok := raw.(map[string]interface{})
		if !ok {
			continue
		}
		if IsTrueValue(e.evaluateExpression(data, branch["case"])) {
			return e.evaluateExpression(data, branch["then"])
		}
	}

	// The default is an expression like any "then" branch, so evaluate it too.
	if def, ok := spec["default"]; ok && def != nil {
		return e.evaluateExpression(data, def)
	}
	return nil
}
// getFieldValueStr returns the string form of a field value by delegating to
// the package-level GetFieldValueStr. The method is kept as a
// backward-compatible wrapper now that the implementation lives in
// helpers.go.
func (e *AggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string {
	value := GetFieldValueStr(doc, field)
	return value
}
// executeAddFields runs the $addFields / $set stage.
//
// spec must be a document mapping output field names to expressions; any
// other spec returns docs unchanged. For each input document a deep copy of
// its data is made, every expression is evaluated against that copy, and the
// results are then written into the copy under their field names. Input
// documents are not modified.
//
// All expressions are evaluated before any result is stored. Go map iteration
// order is unspecified, so storing results while still evaluating would make
// the outcome nondeterministic whenever one new field names another.
func (e *AggregationEngine) executeAddFields(spec interface{}, docs []types.Document) ([]types.Document, error) {
	fields, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	var results []types.Document
	// Scratch map reused across documents; every key is overwritten per doc.
	computed := make(map[string]interface{}, len(fields))
	for _, doc := range docs {
		newData := DeepCopyMap(doc.Data)
		if newData == nil {
			// Guard against writing into a nil map when the source has no data.
			newData = make(map[string]interface{}, len(fields))
		}
		for field, expr := range fields {
			computed[field] = e.evaluateExpression(newData, expr)
		}
		for field, value := range computed {
			newData[field] = value
		}
		results = append(results, types.Document{
			ID:   doc.ID,
			Data: newData,
		})
	}
	return results, nil
}
// executeUnset runs the $unset stage.
//
// spec is either a single field path (string) or an array of field paths;
// non-string array entries are ignored and any other spec type returns docs
// unchanged. Each output document carries the input ID and a deep copy of the
// input data with the listed paths removed via RemoveNestedValue.
func (e *AggregationEngine) executeUnset(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var paths []string
	if single, isString := spec.(string); isString {
		paths = []string{single}
	} else if list, isList := spec.([]interface{}); isList {
		paths = make([]string, 0, len(list))
		for _, entry := range list {
			name, isName := entry.(string)
			if !isName {
				continue
			}
			paths = append(paths, name)
		}
	} else {
		return docs, nil
	}

	var out []types.Document
	for i := range docs {
		trimmed := DeepCopyMap(docs[i].Data)
		for _, p := range paths {
			RemoveNestedValue(trimmed, p)
		}
		out = append(out, types.Document{ID: docs[i].ID, Data: trimmed})
	}
	return out, nil
}
// executeFacet runs the $facet stage.
//
// spec must be a document mapping facet names to sub-pipelines (arrays of
// single-key stage documents); any other spec returns docs unchanged. Each
// sub-pipeline is run over the same input docs through ExecutePipeline and
// its output is stored under the facet name. Facets whose value is not an
// array, and stage entries that are not documents, are skipped. The result is
// a single document with ID "facet". The first sub-pipeline error aborts the
// stage.
func (e *AggregationEngine) executeFacet(spec interface{}, docs []types.Document) ([]types.Document, error) {
	facets, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	output := make(map[string]interface{})
	for facetName, rawPipeline := range facets {
		stageList, isList := rawPipeline.([]interface{})
		if !isList {
			continue
		}

		stages := make([]types.AggregateStage, 0, len(stageList))
		for _, rawStage := range stageList {
			stageMap, isMap := rawStage.(map[string]interface{})
			if !isMap {
				continue
			}
			// Only one key per stage document is used; with several keys the
			// one picked is unspecified (Go map iteration order).
			for stageName, stageSpec := range stageMap {
				stages = append(stages, types.AggregateStage{Stage: stageName, Spec: stageSpec})
				break
			}
		}

		facetDocs, err := e.ExecutePipeline(docs, stages)
		if err != nil {
			return nil, err
		}
		output[facetName] = facetDocs
	}

	return []types.Document{{ID: "facet", Data: output}}, nil
}
// executeSample runs the $sample stage, returning size documents chosen
// uniformly at random without replacement.
//
// spec is either a document with a numeric "size" key or a bare number
// (float64, int, int32 or int64); any other spec type returns docs unchanged.
// When size is not positive, or is at least len(docs), docs is returned as-is
// (not shuffled). The input slice is never reordered.
func (e *AggregationEngine) executeSample(spec interface{}, docs []types.Document) ([]types.Document, error) {
	size := 0
	switch s := spec.(type) {
	case map[string]interface{}:
		if sizeVal, ok := s["size"]; ok {
			size = int(ToFloat64(sizeVal))
		}
	case float64:
		size = int(s)
	case int:
		size = s
	case int32:
		size = int(s)
	case int64:
		size = int(s)
	default:
		return docs, nil
	}
	if size <= 0 || size >= len(docs) {
		return docs, nil
	}

	// Partial Fisher-Yates: only the first size positions need to be drawn,
	// each from the not-yet-chosen tail, to get a uniform sample.
	shuffled := make([]types.Document, len(docs))
	copy(shuffled, docs)
	for i := 0; i < size; i++ {
		j := i + rand.Intn(len(shuffled)-i)
		shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
	}
	// Cap the capacity so a caller's append cannot expose the unsampled tail.
	return shuffled[:size:size], nil
}
// executeBucket runs the $bucket stage.
//
// spec must be a document; any other spec returns docs unchanged. Recognised
// keys:
//   - "groupBy": string path handed to GetNestedValue to read each document's
//     value (converted with ToFloat64);
//   - "boundaries": array of numbers; each adjacent pair [lo, hi) forms a
//     bucket labelled "lo-hi";
//   - "default": optional; when non-nil, documents matching no range are
//     counted under its %v-formatted label.
//
// One output document is produced per bucket, including empty ones, with
// "_id" set to the label and "count" to the number of matching documents.
// Output follows boundary order with the default bucket last, so results are
// deterministic.
func (e *AggregationEngine) executeBucket(spec interface{}, docs []types.Document) ([]types.Document, error) {
	bucketSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}
	groupBy, _ := bucketSpec["groupBy"].(string)
	boundariesRaw, _ := bucketSpec["boundaries"].([]interface{})
	defaultVal := bucketSpec["default"]

	// Convert the boundaries to float64 once.
	boundaries := make([]float64, 0, len(boundariesRaw))
	for _, b := range boundariesRaw {
		boundaries = append(boundaries, ToFloat64(b))
	}

	// names keeps first-seen label order; counts holds the tally per label.
	names := make([]string, 0, len(boundaries)+1)
	counts := make(map[string]int, len(boundaries)+1)
	register := func(name string) {
		if _, seen := counts[name]; !seen {
			counts[name] = 0
			names = append(names, name)
		}
	}

	// rangeNames[i] is the label of [boundaries[i], boundaries[i+1]).
	rangeNames := make([]string, 0, len(boundaries))
	for i := 0; i+1 < len(boundaries); i++ {
		name := fmt.Sprintf("%v-%v", boundaries[i], boundaries[i+1])
		rangeNames = append(rangeNames, name)
		register(name)
	}
	defaultName := ""
	if defaultVal != nil {
		defaultName = fmt.Sprintf("%v", defaultVal)
		register(defaultName)
	}

	// Assign each document to the first matching range, else the default.
	for _, doc := range docs {
		value := ToFloat64(GetNestedValue(doc.Data, groupBy))
		bucketName := defaultName
		for i := 0; i+1 < len(boundaries); i++ {
			if value >= boundaries[i] && value < boundaries[i+1] {
				bucketName = rangeNames[i]
				break
			}
		}
		// An empty label means "no range matched and no usable default".
		if bucketName != "" {
			counts[bucketName]++
		}
	}

	// Emit buckets in registration order.
	var results []types.Document
	for _, name := range names {
		results = append(results, types.Document{
			ID: name,
			Data: map[string]interface{}{
				"_id":   name,
				"count": counts[name],
			},
		})
	}
	return results, nil
}
// ExecutePipeline runs the given stages in order, feeding each stage's output
// into the next, and returns the final document set. It is used by $facet to
// run sub-pipelines. The first stage error aborts execution and is returned
// with a nil result. An empty pipeline returns docs unchanged.
func (e *AggregationEngine) ExecutePipeline(docs []types.Document, pipeline []types.AggregateStage) ([]types.Document, error) {
	current := docs
	for i := range pipeline {
		next, err := e.executeStage(pipeline[i], current)
		if err != nil {
			return nil, err
		}
		current = next
	}
	return current, nil
}
// ========== 算术表达式操作符 ==========
// abs evaluates the operand, converts it with ToFloat64, and returns its
// absolute value.
func (e *AggregationEngine) abs(operand interface{}, data map[string]interface{}) float64 {
	v := ToFloat64(e.evaluateExpression(data, operand))
	switch {
	case v < 0:
		return -v
	default:
		return v
	}
}
// ceil evaluates the operand, converts it with ToFloat64, and rounds it up to
// the nearest integer value via math.Ceil.
func (e *AggregationEngine) ceil(operand interface{}, data map[string]interface{}) float64 {
	return math.Ceil(ToFloat64(e.evaluateExpression(data, operand)))
}
// floor evaluates the operand, converts it with ToFloat64, and rounds it down
// to the nearest integer value via math.Floor.
func (e *AggregationEngine) floor(operand interface{}, data map[string]interface{}) float64 {
	return math.Floor(ToFloat64(e.evaluateExpression(data, operand)))
}
// round rounds a number to a given number of decimal places using
// RoundToPrecision.
//
// The operand is either a single expression (precision 0) or an array
// [value] / [value, precision]. The value is evaluated against data; the
// precision is converted directly with ToFloat64. An empty array yields 0,
// like the other arithmetic helpers do for malformed operands.
func (e *AggregationEngine) round(operand interface{}, data map[string]interface{}) float64 {
	valueExpr := operand
	precision := 0

	if args, ok := operand.([]interface{}); ok {
		if len(args) == 0 {
			// Nothing to round; indexing args[0] here would panic.
			return 0
		}
		valueExpr = args[0]
		if len(args) > 1 {
			precision = int(ToFloat64(args[1]))
		}
	}

	value := ToFloat64(e.evaluateExpression(data, valueExpr))
	return RoundToPrecision(value, precision)
}
// sqrt evaluates the operand, converts it with ToFloat64, and returns its
// square root via math.Sqrt.
func (e *AggregationEngine) sqrt(operand interface{}, data map[string]interface{}) float64 {
	return math.Sqrt(ToFloat64(e.evaluateExpression(data, operand)))
}
// subtract evaluates an array operand and returns the first element minus
// every following element. A non-array operand or one with fewer than two
// elements yields 0.
func (e *AggregationEngine) subtract(operand interface{}, data map[string]interface{}) float64 {
	terms, isList := operand.([]interface{})
	if !isList || len(terms) < 2 {
		return 0
	}
	diff := ToFloat64(e.evaluateExpression(data, terms[0]))
	for _, term := range terms[1:] {
		diff -= ToFloat64(e.evaluateExpression(data, term))
	}
	return diff
}
// pow evaluates a two-element array operand [base, exponent] and returns
// base raised to exponent via math.Pow. Any other operand shape yields 0.
func (e *AggregationEngine) pow(operand interface{}, data map[string]interface{}) float64 {
	args, isList := operand.([]interface{})
	if !isList || len(args) != 2 {
		return 0
	}
	radix := ToFloat64(e.evaluateExpression(data, args[0]))
	exponent := ToFloat64(e.evaluateExpression(data, args[1]))
	return math.Pow(radix, exponent)
}
// ========== 字符串表达式操作符 ==========
// trim strips characters from both ends of a string.
//
// Operand shapes:
//   - document: "input" is read through GetFieldValueStr; the optional
//     "characters" (or "chars") string is the set of characters to strip;
//   - string: read through GetFieldValueStr;
//   - anything else: rendered with FormatValueToString.
//
// The strip set defaults to a single space. A non-string strip set is ignored
// rather than causing a panic.
func (e *AggregationEngine) trim(operand interface{}, data map[string]interface{}) string {
	var input string
	chars := " "

	switch op := operand.(type) {
	case map[string]interface{}:
		if in, ok := op["input"]; ok {
			input = GetFieldValueStr(types.Document{Data: data}, in)
		}
		// "characters" is the historical key; "chars" is the MongoDB spelling.
		if c, ok := op["characters"].(string); ok {
			chars = c
		} else if c, ok := op["chars"].(string); ok {
			chars = c
		}
	case string:
		input = GetFieldValueStr(types.Document{Data: data}, op)
	default:
		input = FormatValueToString(operand)
	}
	return strings.Trim(input, chars)
}
// ltrim strips characters from the left end of a string.
//
// A document operand is treated like trim's: "input" is read through
// GetFieldValueStr and the optional "characters" (or "chars") string is the
// set of characters to strip. Any other operand is read through
// GetFieldValueStr directly. The strip set defaults to a single space.
func (e *AggregationEngine) ltrim(operand interface{}, data map[string]interface{}) string {
	doc := types.Document{Data: data}

	spec, isSpec := operand.(map[string]interface{})
	if !isSpec {
		return strings.TrimLeft(GetFieldValueStr(doc, operand), " ")
	}

	var input string
	if in, ok := spec["input"]; ok {
		input = GetFieldValueStr(doc, in)
	}
	cutset := " "
	if c, ok := spec["characters"].(string); ok {
		cutset = c
	} else if c, ok := spec["chars"].(string); ok {
		cutset = c
	}
	return strings.TrimLeft(input, cutset)
}
// rtrim strips characters from the right end of a string.
//
// A document operand is treated like trim's: "input" is read through
// GetFieldValueStr and the optional "characters" (or "chars") string is the
// set of characters to strip. Any other operand is read through
// GetFieldValueStr directly. The strip set defaults to a single space.
func (e *AggregationEngine) rtrim(operand interface{}, data map[string]interface{}) string {
	doc := types.Document{Data: data}

	spec, isSpec := operand.(map[string]interface{})
	if !isSpec {
		return strings.TrimRight(GetFieldValueStr(doc, operand), " ")
	}

	var input string
	if in, ok := spec["input"]; ok {
		input = GetFieldValueStr(doc, in)
	}
	cutset := " "
	if c, ok := spec["characters"].(string); ok {
		cutset = c
	} else if c, ok := spec["chars"].(string); ok {
		cutset = c
	}
	return strings.TrimRight(input, cutset)
}
// split divides a string around a delimiter.
//
// The operand must be a two-element array [input, delimiter] where delimiter
// is a string; any other shape yields nil. The input is read through
// GetFieldValueStr and split with strings.Split; the pieces are returned as a
// []interface{} of strings.
func (e *AggregationEngine) split(operand interface{}, data map[string]interface{}) []interface{} {
	arr, ok := operand.([]interface{})
	if !ok || len(arr) != 2 {
		return nil
	}
	// A non-string delimiter is a malformed operand, not a reason to panic.
	delimiter, ok := arr[1].(string)
	if !ok {
		return nil
	}

	input := GetFieldValueStr(types.Document{Data: data}, arr[0])
	parts := strings.Split(input, delimiter)
	result := make([]interface{}, len(parts))
	for i, part := range parts {
		result[i] = part
	}
	return result
}
// replaceAll replaces every occurrence of a substring.
//
// The operand must be a document with "input", "find" and an optional
// "replacement"; a non-document operand yields "". The input is read through
// GetFieldValueStr and the replacement is rendered with FormatValueToString
// (defaulting to ""). When "find" is missing or not a string there is nothing
// to search for, so the input is returned unchanged instead of panicking.
func (e *AggregationEngine) replaceAll(operand interface{}, data map[string]interface{}) string {
	spec, ok := operand.(map[string]interface{})
	if !ok {
		return ""
	}

	input := GetFieldValueStr(types.Document{Data: data}, spec["input"])
	find, ok := spec["find"].(string)
	if !ok {
		return input
	}

	replacement := ""
	if rep, ok := spec["replacement"]; ok {
		replacement = FormatValueToString(rep)
	}
	return strings.ReplaceAll(input, find, replacement)
}
// strcasecmp compares two strings case-insensitively.
//
// The operand must be a two-element array; both elements are read through
// GetFieldValueStr, lower-cased, and compared byte-wise. It returns -1, 0 or
// 1 when the first string sorts before, equal to, or after the second. Any
// other operand shape yields 0.
func (e *AggregationEngine) strcasecmp(operand interface{}, data map[string]interface{}) int {
	args, isList := operand.([]interface{})
	if !isList || len(args) != 2 {
		return 0
	}

	left := strings.ToLower(GetFieldValueStr(types.Document{Data: data}, args[0]))
	right := strings.ToLower(GetFieldValueStr(types.Document{Data: data}, args[1]))

	switch {
	case left < right:
		return -1
	case left > right:
		return 1
	default:
		return 0
	}
}
// ========== 集合表达式操作符 ==========
// filter keeps the elements of an array for which a condition is truthy.
//
// The operand must be a document with "input", an optional "as" variable name
// (default "item") and "cond"; a non-document operand yields nil. The input
// is converted with toArray. For each element, cond is evaluated against a
// shallow copy of data with the element bound under the key "$$<as>", and the
// element is kept when IsTrueValue accepts the result. The result is nil when
// nothing is kept.
func (e *AggregationEngine) filter(operand interface{}, data map[string]interface{}) []interface{} {
	spec, ok := operand.(map[string]interface{})
	if !ok {
		return nil
	}

	elements := e.toArray(spec["input"])
	varName, _ := spec["as"].(string)
	if varName == "" {
		varName = "item"
	}
	scopedKey := "$$" + varName
	predicate := spec["cond"]

	var kept []interface{}
	for _, element := range elements {
		// Fresh scope per element so the binding never leaks into data.
		scope := make(map[string]interface{}, len(data)+1)
		for key, val := range data {
			scope[key] = val
		}
		scope[scopedKey] = element

		if !IsTrueValue(e.evaluateExpression(scope, predicate)) {
			continue
		}
		kept = append(kept, element)
	}
	return kept
}
// mapArr applies an expression to every element of an array.
//
// The operand must be a document with "input", an optional "as" variable name
// (default "item") and "in"; a non-document operand yields nil. The input is
// converted with toArray. For each element, the "in" expression is evaluated
// against a shallow copy of data with the element bound under the key
// "$$<as>", and the results are collected in order. The result is nil when
// the input has no elements.
func (e *AggregationEngine) mapArr(operand interface{}, data map[string]interface{}) []interface{} {
	spec, ok := operand.(map[string]interface{})
	if !ok {
		return nil
	}

	elements := e.toArray(spec["input"])
	varName, _ := spec["as"].(string)
	if varName == "" {
		varName = "item"
	}
	scopedKey := "$$" + varName
	projection := spec["in"]

	var mapped []interface{}
	for _, element := range elements {
		// Fresh scope per element so the binding never leaks into data.
		scope := make(map[string]interface{}, len(data)+1)
		for key, val := range data {
			scope[key] = val
		}
		scope[scopedKey] = element

		mapped = append(mapped, e.evaluateExpression(scope, projection))
	}
	return mapped
}
// concatArrays concatenates the arrays listed in an array operand.
//
// Each element is converted with toArray; elements for which toArray returns
// nil are skipped. A non-array operand yields nil. The data argument is not
// consulted by this helper.
func (e *AggregationEngine) concatArrays(operand interface{}, data map[string]interface{}) []interface{} {
	parts, isList := operand.([]interface{})
	if !isList {
		return nil
	}

	var joined []interface{}
	for _, part := range parts {
		elems := e.toArray(part)
		if elems == nil {
			continue
		}
		joined = append(joined, elems...)
	}
	return joined
}
// slice returns a contiguous sub-range of an array.
//
// The operand must be an array [source, skip] or [source, skip, limit];
// anything else yields nil. The source is converted with toArray; skip and
// limit are converted with ToFloat64. Without a limit, everything from skip
// to the end is returned.
//
// Returns nil when the source converts to nil or skip is negative, and an
// empty slice when skip is at or past the end or limit is not positive. The
// result shares storage with the source, but its capacity is capped so an
// append by the caller cannot overwrite the source's remaining elements.
func (e *AggregationEngine) slice(operand interface{}, data map[string]interface{}) []interface{} {
	args, ok := operand.([]interface{})
	if !ok || len(args) < 2 {
		return nil
	}

	arr := e.toArray(args[0])
	skip := int(ToFloat64(args[1]))
	if arr == nil || skip < 0 {
		return nil
	}
	if skip >= len(arr) {
		return []interface{}{}
	}

	remaining := len(arr) - skip
	limit := remaining
	if len(args) > 2 {
		limit = int(ToFloat64(args[2]))
	}
	// A non-positive limit selects nothing; slicing with end < skip would panic.
	if limit <= 0 {
		return []interface{}{}
	}
	// Clamp against remaining rather than computing skip+limit, which could overflow.
	if limit > remaining {
		limit = remaining
	}

	end := skip + limit
	return arr[skip:end:end]
}
// ========== 对象表达式操作符 ==========
// mergeObjects shallow-merges the documents listed in an array operand into a
// new map. Later documents overwrite keys set by earlier ones; elements that
// are not documents are skipped. A non-array operand yields nil. The data
// argument is not consulted by this helper.
func (e *AggregationEngine) mergeObjects(operand interface{}, data map[string]interface{}) map[string]interface{} {
	sources, isList := operand.([]interface{})
	if !isList {
		return nil
	}

	merged := map[string]interface{}{}
	for i := range sources {
		obj, isObj := sources[i].(map[string]interface{})
		if !isObj {
			continue
		}
		for key, val := range obj {
			merged[key] = val
		}
	}
	return merged
}
// objectToArray converts a document operand into an array of
// {"k": key, "v": value} documents, one per field. The order of the entries
// follows Go map iteration and is therefore unspecified. A non-document
// operand yields nil. The data argument is not consulted by this helper.
func (e *AggregationEngine) objectToArray(operand interface{}, data map[string]interface{}) []interface{} {
	fields, isObj := operand.(map[string]interface{})
	if !isObj {
		return nil
	}

	pairs := make([]interface{}, 0, len(fields))
	for key := range fields {
		entry := map[string]interface{}{
			"k": key,
			"v": fields[key],
		}
		pairs = append(pairs, entry)
	}
	return pairs
}
// ========== 辅助函数 ==========
// toArray coerces a value to an array (kept for backward compatibility).
//
//   - an array is returned as-is;
//   - a document yields nil;
//   - any other value, including nil, is wrapped in a one-element array.
func (e *AggregationEngine) toArray(value interface{}) []interface{} {
	if list, isList := value.([]interface{}); isList {
		return list
	}
	if _, isDoc := value.(map[string]interface{}); isDoc {
		// Documents are not treated as arrays.
		return nil
	}
	// Wrap a single value so callers can iterate uniformly.
	return []interface{}{value}
}
// boolAnd evaluates the elements of an array operand in order and reports
// whether all of them are truthy per IsTrueValue, stopping at the first that
// is not. An empty array yields true; a non-array operand yields false.
func (e *AggregationEngine) boolAnd(operand interface{}, data map[string]interface{}) bool {
	conds, isList := operand.([]interface{})
	if !isList {
		return false
	}
	for i := 0; i < len(conds); i++ {
		if IsTrueValue(e.evaluateExpression(data, conds[i])) {
			continue
		}
		return false
	}
	return true
}
// boolOr evaluates the elements of an array operand in order and reports
// whether any of them is truthy per IsTrueValue, stopping at the first that
// is. An empty array or a non-array operand yields false.
func (e *AggregationEngine) boolOr(operand interface{}, data map[string]interface{}) bool {
	conds, isList := operand.([]interface{})
	if !isList {
		return false
	}
	for i := 0; i < len(conds); i++ {
		if !IsTrueValue(e.evaluateExpression(data, conds[i])) {
			continue
		}
		return true
	}
	return false
}
// boolNot evaluates the operand and returns the negation of its truthiness
// as determined by IsTrueValue.
func (e *AggregationEngine) boolNot(operand interface{}, data map[string]interface{}) bool {
	truthy := IsTrueValue(e.evaluateExpression(data, operand))
	return !truthy
}