gomog/internal/engine/aggregate_batch5.go

369 lines
9.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package engine
import (
"encoding/json"
"fmt"
"git.kingecg.top/kingecg/gomog/pkg/errors"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// Sentinel values returned by a $redact expression to control document
// visibility (mirror MongoDB's system variables $$DESCEND/$$PRUNE/$$KEEP).
const (
RedactDescend = "$$DESCEND"
RedactPrune = "$$PRUNE"
RedactKeep = "$$KEEP"
)
// executeUnionWith implements the $unionWith stage: it appends documents
// from another collection (optionally pre-processed by a sub-pipeline)
// to the current document stream.
func (e *AggregationEngine) executeUnionWith(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var (
		collName string
		stages   []types.AggregateStage
	)

	// The stage accepts either the shorthand string form
	// { $unionWith: "coll" } or the full document form
	// { $unionWith: { coll: "...", pipeline: [...] } }.
	switch v := spec.(type) {
	case string:
		collName = v
		stages = []types.AggregateStage{}
	case map[string]interface{}:
		name, ok := v["coll"].(string)
		if !ok {
			// Malformed spec: pass the input stream through unchanged.
			return docs, nil
		}
		collName = name
		rawStages, _ := v["pipeline"].([]interface{})
		for _, raw := range rawStages {
			stageDoc, ok := raw.(map[string]interface{})
			if !ok {
				continue
			}
			// Each pipeline element is a single-key document; take the
			// first (only) key as the stage name.
			for stageName, stageSpec := range stageDoc {
				stages = append(stages, types.AggregateStage{
					Stage: stageName,
					Spec:  stageSpec,
				})
				break
			}
		}
	default:
		return docs, nil
	}

	// Fetch every document from the union collection; a missing
	// collection contributes nothing rather than failing the stage.
	unionDocs, err := e.store.GetAllDocuments(collName)
	if err != nil {
		unionDocs = []types.Document{}
	}

	// Run the optional sub-pipeline over the union documents only.
	if len(stages) > 0 {
		unionDocs, err = e.ExecutePipeline(unionDocs, stages)
		if err != nil {
			return nil, errors.Wrap(err, errors.ErrAggregationError, "failed to execute union pipeline")
		}
	}

	// Concatenate: original documents first, union documents after.
	merged := make([]types.Document, 0, len(docs)+len(unionDocs))
	merged = append(merged, docs...)
	merged = append(merged, unionDocs...)
	return merged, nil
}
// executeRedact implements the $redact stage: each document is filtered
// through the redact expression and kept only when the expression does
// not prune it.
func (e *AggregationEngine) executeRedact(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var kept []types.Document
	for _, doc := range docs {
		redacted, keep := e.redactDocument(doc.Data, spec)
		if !keep {
			continue
		}
		// doc.Data is a map, so a kept result is always a map as well.
		kept = append(kept, types.Document{
			ID:   doc.ID,
			Data: redacted.(map[string]interface{}),
		})
	}
	return kept, nil
}
// redactDocument applies the redact expression to a single value and
// reports whether the value should be kept.
func (e *AggregationEngine) redactDocument(data interface{}, spec interface{}) (interface{}, bool) {
	// Only documents (maps) are evaluated against the expression;
	// scalars and other values are always kept as-is.
	asMap, ok := data.(map[string]interface{})
	if !ok {
		return data, true
	}

	switch e.evaluateExpression(asMap, spec) {
	case RedactKeep:
		// Keep the whole subtree without descending further.
		return data, true
	case RedactPrune:
		// Drop the whole subtree.
		return nil, false
	default:
		// $$DESCEND — and any unrecognized result — recurses into
		// nested documents and arrays.
		return e.redactNested(data, spec)
	}
}
// redactNested recurses into the children of a document or array,
// applying the redact expression to each element.
func (e *AggregationEngine) redactNested(data interface{}, spec interface{}) (interface{}, bool) {
	if doc, ok := data.(map[string]interface{}); ok {
		return e.redactMap(doc, spec)
	}
	if arr, ok := data.([]interface{}); ok {
		return e.redactArray(arr, spec)
	}
	// Scalars have no children; keep them unchanged.
	return data, true
}
// redactMap redacts every field of a document, omitting fields whose
// values are pruned. The resulting document itself is always kept.
func (e *AggregationEngine) redactMap(m map[string]interface{}, spec interface{}) (map[string]interface{}, bool) {
	out := make(map[string]interface{})
	for key, val := range m {
		if redacted, keep := e.redactDocument(val, spec); keep {
			out[key] = redacted
		}
	}
	return out, true
}
// redactArray redacts every element of an array, dropping pruned
// elements. The resulting array itself is always kept.
func (e *AggregationEngine) redactArray(arr []interface{}, spec interface{}) ([]interface{}, bool) {
	out := make([]interface{}, 0)
	for _, elem := range arr {
		if redacted, keep := e.redactDocument(elem, spec); keep {
			out = append(out, redacted)
		}
	}
	return out, true
}
// executeOut implements the $out stage: it replaces the contents of a
// target collection with the current document stream and returns a
// single acknowledgement document.
//
// The spec may be a plain collection name string, or a document of the
// form { db: "...", coll: "..." }.
func (e *AggregationEngine) executeOut(spec interface{}, docs []types.Document, currentCollection string) ([]types.Document, error) {
	var targetCollection string
	switch s := spec.(type) {
	case string:
		targetCollection = s
	case map[string]interface{}:
		// BUG FIX: the previous code asserted s["coll"].(string) without
		// the comma-ok form and panicked when "coll" was missing or not
		// a string; reject such specs with an error instead.
		coll, ok := s["coll"].(string)
		if !ok {
			return nil, errors.New(errors.ErrInvalidRequest, "invalid $out specification")
		}
		if db, ok := s["db"].(string); ok && db != "" {
			targetCollection = db + "." + coll
		} else {
			targetCollection = coll
		}
	default:
		return nil, errors.New(errors.ErrInvalidRequest, "invalid $out specification")
	}

	// $out overwrites the target, so drop any existing collection first.
	// A "collection not found" error just means there is nothing to drop.
	// NOTE(review): direct sentinel comparison assumes the store returns
	// the sentinel unwrapped — confirm against the store implementation.
	if err := e.store.DropCollection(targetCollection); err != nil && err != errors.ErrCollectionNotFnd {
		return nil, errors.Wrap(err, errors.ErrDatabaseError, "failed to drop target collection")
	}

	// Re-create the collection by inserting every input document.
	for _, doc := range docs {
		if err := e.store.InsertDocument(targetCollection, doc); err != nil {
			return nil, errors.Wrap(err, errors.ErrDatabaseError, "failed to insert document")
		}
	}

	// Acknowledgement document with MongoDB-style write stats.
	return []types.Document{{
		Data: map[string]interface{}{
			"ok":               float64(1),
			"nInserted":        float64(len(docs)),
			"targetCollection": targetCollection,
		},
	}}, nil
}
// executeMerge implements the $merge stage: it merges the current
// document stream into a target collection according to the
// whenMatched / whenNotMatched strategies and returns a single
// statistics document.
func (e *AggregationEngine) executeMerge(spec interface{}, docs []types.Document, currentCollection string) ([]types.Document, error) {
	mergeSpec, ok := spec.(map[string]interface{})
	if !ok {
		return nil, errors.New(errors.ErrInvalidRequest, "invalid $merge specification")
	}

	// "into" is either a collection name or a { coll: "..." } document.
	var targetCollection string
	switch into := mergeSpec["into"].(type) {
	case string:
		targetCollection = into
	case map[string]interface{}:
		// BUG FIX: previously an unchecked into["coll"].(string)
		// panicked when "coll" was missing or not a string.
		coll, ok := into["coll"].(string)
		if !ok {
			return nil, errors.New(errors.ErrInvalidRequest, "invalid $merge into specification")
		}
		targetCollection = coll
	default:
		return nil, errors.New(errors.ErrInvalidRequest, "invalid $merge into specification")
	}

	// Join field defaults to _id. NOTE(review): MongoDB also allows an
	// array of field names for "on"; only the single-string form is
	// handled here.
	onField, _ := mergeSpec["on"].(string)
	if onField == "" {
		onField = "_id"
	}

	// Strategy defaults mirror MongoDB: replace on match, insert on miss.
	whenMatched, _ := mergeSpec["whenMatched"].(string)
	if whenMatched == "" {
		whenMatched = "replace"
	}
	whenNotMatched, _ := mergeSpec["whenNotMatched"].(string)
	if whenNotMatched == "" {
		whenNotMatched = "insert"
	}

	// Index the target's existing documents by join key. A missing
	// target collection is treated as empty (best-effort read).
	existingDocs, _ := e.store.GetAllDocuments(targetCollection)
	existingMap := make(map[string]types.Document, len(existingDocs))
	for _, doc := range existingDocs {
		existingMap[getDocumentKey(doc, onField)] = doc
	}

	stats := map[string]float64{
		"nInserted":  0,
		"nUpdated":   0,
		"nUnchanged": 0,
		"nDeleted":   0,
	}

	for _, doc := range docs {
		key := getDocumentKey(doc, onField)
		existing, exists := existingMap[key]
		if !exists {
			if whenNotMatched == "insert" {
				// BUG FIX: insert errors were silently dropped.
				if err := e.store.InsertDocument(targetCollection, doc); err != nil {
					return nil, errors.Wrap(err, errors.ErrDatabaseError, "failed to insert document")
				}
				stats["nInserted"]++
			}
			continue
		}
		switch whenMatched {
		case "replace":
			// BUG FIX: update errors were silently dropped.
			if err := e.store.UpdateDocument(targetCollection, doc); err != nil {
				return nil, errors.Wrap(err, errors.ErrDatabaseError, "failed to update document")
			}
			stats["nUpdated"]++
		case "keepExisting":
			stats["nUnchanged"]++
		case "merge":
			// Shallow merge: incoming fields overwrite existing ones.
			mergedData := deepCopyMap(existing.Data)
			for k, v := range doc.Data {
				mergedData[k] = v
			}
			doc.Data = mergedData
			if err := e.store.UpdateDocument(targetCollection, doc); err != nil {
				return nil, errors.Wrap(err, errors.ErrDatabaseError, "failed to update document")
			}
			stats["nUpdated"]++
		case "fail":
			return nil, errors.New(errors.ErrDuplicateKey, "document already exists")
		case "delete":
			// TODO(review): only the counter is updated — no document is
			// actually removed from the store. Confirm whether the store
			// exposes a delete operation and wire it in.
			stats["nDeleted"]++
		}
	}

	return []types.Document{{
		Data: map[string]interface{}{
			"ok":         float64(1),
			"nInserted":  stats["nInserted"],
			"nUpdated":   stats["nUpdated"],
			"nUnchanged": stats["nUnchanged"],
			"nDeleted":   stats["nDeleted"],
		},
	}}, nil
}
// getDocumentKey derives the merge key for a document. The built-in
// _id field maps directly to doc.ID; any other field is resolved via
// getNestedValue (presumably dot-path aware — confirm) and
// stringified. A missing field yields the empty key.
func getDocumentKey(doc types.Document, keyField string) string {
	if keyField == "_id" {
		return doc.ID
	}
	v := getNestedValue(doc.Data, keyField)
	if v == nil {
		return ""
	}
	return fmt.Sprintf("%v", v)
}
// executeIndexStats implements a simplified $indexStats stage. The
// in-memory engine tracks no real index metrics, so it returns a single
// static stats document describing the implicit _id index.
func (e *AggregationEngine) executeIndexStats(spec interface{}, docs []types.Document) ([]types.Document, error) {
	accesses := map[string]interface{}{
		"ops":   float64(0),
		"since": "2024-01-01T00:00:00Z",
	}
	stats := map[string]interface{}{
		"name":     "id_idx",
		"key":      map[string]interface{}{"_id": 1},
		"accesses": accesses,
	}
	return []types.Document{{Data: stats}}, nil
}
// executeCollStats implements a simplified $collStats stage, reporting
// the document count and an estimated JSON size of the current stream.
// Namespace and index figures are placeholders for the in-memory store.
func (e *AggregationEngine) executeCollStats(spec interface{}, docs []types.Document) ([]types.Document, error) {
	stats := map[string]interface{}{
		"ns":          "test.collection",
		"count":       float64(len(docs)),
		"size":        estimateSize(docs),
		"storageSize": float64(0), // not meaningful for in-memory storage
		"nindexes":    float64(1),
	}
	return []types.Document{{Data: stats}}, nil
}
// estimateSize approximates the total size of the documents in bytes by
// summing the lengths of their JSON encodings. A document that fails to
// marshal contributes zero bytes — the error is deliberately ignored
// because this is a best-effort estimate.
func estimateSize(docs []types.Document) float64 {
	var total int
	for _, doc := range docs {
		encoded, _ := json.Marshal(doc.Data)
		total += len(encoded)
	}
	return float64(total)
}