package engine import ( "strings" "time" "git.kingecg.top/kingecg/gomog/pkg/types" ) // applyUpdate 应用更新操作到文档数据 func applyUpdate(data map[string]interface{}, update types.Update, isUpsertInsert bool) map[string]interface{} { return applyUpdateWithFilters(data, update, isUpsertInsert, nil) } // applyUpdateWithFilters 应用更新操作(支持 arrayFilters) func applyUpdateWithFilters(data map[string]interface{}, update types.Update, isUpsertInsert bool, arrayFilters []types.Filter) map[string]interface{} { // 深拷贝原数据 result := deepCopyMap(data) // 处理 $set for field, value := range update.Set { if !updateArrayElement(result, field, value, convertFiltersToMaps(arrayFilters)) { setNestedValue(result, field, value) } } // 处理 $unset for field := range update.Unset { removeNestedValue(result, field) } // 处理 $inc for field, value := range update.Inc { if !updateArrayElement(result, field, value, convertFiltersToMaps(arrayFilters)) { incNestedValue(result, field, value) } } // 处理 $mul for field, value := range update.Mul { if !updateArrayElement(result, field, value, convertFiltersToMaps(arrayFilters)) { mulNestedValue(result, field, value) } } // 处理 $push for field, value := range update.Push { pushNestedValue(result, field, value) } // 处理 $pull for field, value := range update.Pull { pullNestedValue(result, field, value) } // 处理 $min - 仅当值小于当前值时更新 for field, value := range update.Min { current := getNestedValue(result, field) if current == nil || compareNumbers(current, value) > 0 { setNestedValue(result, field, value) } } // 处理 $max - 仅当值大于当前值时更新 for field, value := range update.Max { current := getNestedValue(result, field) if current == nil || compareNumbers(current, value) < 0 { setNestedValue(result, field, value) } } // 处理 $rename - 重命名字段 for oldName, newName := range update.Rename { value := getNestedValue(result, oldName) if value != nil { removeNestedValue(result, oldName) setNestedValue(result, newName, value) } } // 处理 $currentDate - 设置为当前时间 for field, spec := range update.CurrentDate { var currentTime interface{} = time.Now() // 检查是否指定了类型 if specMap, ok := spec.(map[string]interface{}); ok { if typeVal, exists := specMap["$type"]; exists { if typeStr, ok := typeVal.(string); ok && typeStr == "timestamp" { currentTime = time.Now().UnixMilli() } } } setNestedValue(result, field, currentTime) } // 处理 $addToSet - 添加唯一元素到数组 for field, value := range update.AddToSet { current := getNestedValue(result, field) var arr []interface{} if current != nil { if a, ok := current.([]interface{}); ok { arr = a } } if arr == nil { arr = make([]interface{}, 0) } // 检查是否已存在 exists := false for _, item := range arr { if compareEq(item, value) { exists = true break } } if !exists { arr = append(arr, value) setNestedValue(result, field, arr) } } // 处理 $pop - 移除数组首/尾元素 for field, pos := range update.Pop { current := getNestedValue(result, field) if arr, ok := current.([]interface{}); ok && len(arr) > 0 { if pos >= 0 { // 移除最后一个元素 arr = arr[:len(arr)-1] } else { // 移除第一个元素 arr = arr[1:] } setNestedValue(result, field, arr) } } // 处理 $pullAll - 从数组中移除多个值 for field, values := range update.PullAll { current := getNestedValue(result, field) if arr, ok := current.([]interface{}); ok { filtered := make([]interface{}, 0, len(arr)) for _, item := range arr { keep := true for _, removeVal := range values { if compareEq(item, removeVal) { keep = false break } } if keep { filtered = append(filtered, item) } } setNestedValue(result, field, filtered) } } // 处理 $setOnInsert - 仅在 upsert 插入时设置 if isUpsertInsert { for field, value := range update.SetOnInsert { setNestedValue(result, field, value) } } return result } // convertFiltersToMaps 转换 Filter 数组为 map 数组 func convertFiltersToMaps(filters []types.Filter) []map[string]interface{} { if filters == nil { return nil } result := make([]map[string]interface{}, len(filters)) for i, f := range filters { result[i] = map[string]interface{}(f) } return result } // deepCopyMap 深拷贝 map func deepCopyMap(m map[string]interface{}) map[string]interface{} { if m == nil { return nil } result := make(map[string]interface{}) for k, v := range m { switch val := v.(type) { case map[string]interface{}: result[k] = deepCopyMap(val) case []interface{}: result[k] = deepCopySlice(val) default: result[k] = v } } return result } // deepCopySlice 深拷贝 slice func deepCopySlice(s []interface{}) []interface{} { if s == nil { return nil } result := make([]interface{}, len(s)) for i, v := range s { switch val := v.(type) { case map[string]interface{}: result[i] = deepCopyMap(val) case []interface{}: result[i] = deepCopySlice(val) default: result[i] = v } } return result } // setNestedValue 设置嵌套字段值 func setNestedValue(data map[string]interface{}, field string, value interface{}) { parts := splitFieldPath(field) current := data for i, part := range parts { if i == len(parts)-1 { // 最后一个部分,设置值 current[part] = value return } // 中间部分,确保是 map if current[part] == nil { current[part] = make(map[string]interface{}) } if m, ok := current[part].(map[string]interface{}); ok { current = m } else { // 类型不匹配,创建新 map newMap := make(map[string]interface{}) current[part] = newMap current = newMap } } } // removeNestedValue 删除嵌套字段 func removeNestedValue(data map[string]interface{}, field string) { parts := splitFieldPath(field) current := data for i, part := range parts { if i == len(parts)-1 { delete(current, part) return } if m, ok := current[part].(map[string]interface{}); ok { current = m } else { return } } } // incNestedValue 递增嵌套字段值 func incNestedValue(data map[string]interface{}, field string, increment interface{}) { current := getNestedValue(data, field) if current == nil { setNestedValue(data, field, increment) return } newValue := toFloat64(current) + toFloat64(increment) setNestedValue(data, field, newValue) } // mulNestedValue 乘以嵌套字段值 func mulNestedValue(data map[string]interface{}, field string, multiplier interface{}) { current := getNestedValue(data, field) if current == nil { return } newValue := toFloat64(current) * toFloat64(multiplier) setNestedValue(data, field, newValue) } // pushNestedValue 推入数组 func pushNestedValue(data map[string]interface{}, field string, value interface{}) { current := getNestedValue(data, field) var arr []interface{} if current != nil { if a, ok := current.([]interface{}); ok { arr = a } } if arr == nil { arr = make([]interface{}, 0) } arr = append(arr, value) setNestedValue(data, field, arr) } // pullNestedValue 从数组中移除 func pullNestedValue(data map[string]interface{}, field string, value interface{}) { current := getNestedValue(data, field) if current == nil { return } arr, ok := current.([]interface{}) if !ok { return } // 过滤掉匹配的值 filtered := make([]interface{}, 0, len(arr)) for _, item := range arr { if !compareEq(item, value) { filtered = append(filtered, item) } } setNestedValue(data, field, filtered) } // splitFieldPath 分割字段路径(支持 "a.b.c" 格式) func splitFieldPath(field string) []string { // 简单实现,不考虑转义情况 return strings.Split(field, ".") } // generateID 生成唯一 ID func generateID() string { return time.Now().Format("20060102150405.000000000") } // updateArrayElement 更新数组元素(支持 $ 位置操作符) func updateArrayElement(data map[string]interface{}, field string, value interface{}, arrayFilters []map[string]interface{}) bool { parts := splitFieldPath(field) // 查找包含 $ 或 $[] 的部分 for i, part := range parts { if part == "$" || part == "$[]" || (len(part) > 2 && part[0] == '$' && part[1] == '[') { // 需要数组更新 return updateArrayAtPath(data, parts, i, value, arrayFilters) } } // 普通字段更新 setNestedValue(data, field, value) return true } // updateArrayAtPath 在指定路径更新数组 func updateArrayAtPath(data map[string]interface{}, parts []string, index int, value interface{}, arrayFilters []map[string]interface{}) bool { // 获取到数组前的路径(导航到父对象) current := data for i := 0; i < index; i++ { if m, ok := current[parts[i]].(map[string]interface{}); ok { current = m } else if i == index-1 { // 最后一个部分应该是数组字段名,不需要是 map break } else { return false } } // 获取实际的数组字段名(操作符前面的部分) var actualFieldName string if index > 0 { actualFieldName = parts[index-1] } else { return false // 无效的路径 } arrField := parts[index] arr := getNestedValue(data, actualFieldName) array, ok := arr.([]interface{}) if !ok || len(array) == 0 { return false } // 处理不同的位置操作符 if arrField == "$" { // 定位第一个匹配的元素(需要配合查询条件) // 简化实现:更新第一个元素 array[0] = value setNestedValue(data, actualFieldName, array) return true } if arrField == "$[]" { // 更新所有元素 for i := range array { array[i] = value } setNestedValue(data, actualFieldName, array) return true } // 处理 $[identifier] 形式 if len(arrField) > 3 && arrField[0] == '$' && arrField[1] == '[' && arrField[len(arrField)-1] == ']' { identifier := arrField[2 : len(arrField)-1] // 查找匹配的 arrayFilter var filter map[string]interface{} for _, f := range arrayFilters { if idVal, exists := f["identifier"]; exists && idVal == identifier { // 复制 filter 并移除 identifier 字段 filter = make(map[string]interface{}) for k, v := range f { if k != "identifier" { filter[k] = v } } break } } if filter != nil && len(filter) > 0 { // 应用过滤器更新匹配的元素 for i, item := range array { if itemMap, ok := item.(map[string]interface{}); ok { if MatchFilter(itemMap, filter) { // 如果是嵌套字段(如 students.$[elem].grade),需要设置嵌套字段 if index+1 < len(parts) { // 还有后续字段,设置嵌套字段 itemMap[parts[index+1]] = value } else { array[i] = value } } } } setNestedValue(data, actualFieldName, array) return true } } return false }