package database

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"git.kingecg.top/kingecg/gomog/pkg/types"
)

// BaseAdapter is a generic database/sql-backed adapter. Every collection is
// stored as a table with the columns id, data (JSON), created_at and
// updated_at. Operations whose SQL differs per database (CollectionExists,
// ListCollections) return ErrNotImplemented here and must be provided by
// concrete adapters.
//
// NOTE(review): collection names are interpolated directly into SQL text
// (they cannot be bound as parameters); callers must make sure they come
// from a trusted, validated source.
type BaseAdapter struct {
	db         *sql.DB
	driverName string
}

// NewBaseAdapter returns an adapter that opens connections with the given
// database/sql driver name. No connection is made until Connect is called.
func NewBaseAdapter(driverName string) *BaseAdapter {
	return &BaseAdapter{
		driverName: driverName,
	}
}

// GetDB returns the underlying connection pool (nil before Connect) so that
// adapters embedding BaseAdapter can run database-specific queries.
func (a *BaseAdapter) GetDB() *sql.DB {
	return a.db
}

// Connect opens a connection pool for dsn with the adapter's driver and
// verifies it with a ping. The pool is stored even if the ping fails, so
// Close should still be called afterwards.
func (a *BaseAdapter) Connect(ctx context.Context, dsn string) error {
	db, err := sql.Open(a.driverName, dsn)
	if err != nil {
		return err
	}
	a.db = db
	return db.PingContext(ctx)
}

// Close closes the connection pool. It is a no-op if Connect was never
// called.
func (a *BaseAdapter) Close() error {
	if a.db != nil {
		return a.db.Close()
	}
	return nil
}

// Ping checks that the database is reachable.
func (a *BaseAdapter) Ping(ctx context.Context) error {
	return a.db.PingContext(ctx)
}

// CreateCollection creates the table backing a collection if it does not
// already exist.
func (a *BaseAdapter) CreateCollection(ctx context.Context, name string) error {
	// Uniform table layout: id, data (JSON), created_at, updated_at.
	query := fmt.Sprintf(`
		CREATE TABLE IF NOT EXISTS %s (
			id TEXT PRIMARY KEY,
			data JSON NOT NULL,
			created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
			updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
		)`, name)
	_, err := a.db.ExecContext(ctx, query)
	return err
}

// DropCollection drops the table backing a collection if it exists.
func (a *BaseAdapter) DropCollection(ctx context.Context, name string) error {
	query := fmt.Sprintf("DROP TABLE IF EXISTS %s", name)
	_, err := a.db.ExecContext(ctx, query)
	return err
}

// CollectionExists reports whether a collection exists. The system catalog
// differs per database, so concrete adapters must implement this; the base
// version always returns ErrNotImplemented.
func (a *BaseAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
	return false, ErrNotImplemented
}

// InsertMany inserts docs into collection inside a single transaction. Each
// document's Data is stored as JSON and both timestamps are set to the
// insertion time. Nothing is committed if any insert fails.
func (a *BaseAdapter) InsertMany(ctx context.Context, collection string, docs []types.Document) error {
	tx, err := a.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a harmless no-op.
	defer tx.Rollback()

	stmt, err := tx.PrepareContext(ctx, fmt.Sprintf(
		"INSERT INTO %s (id, data, created_at, updated_at) VALUES (?, ?, ?, ?)", collection))
	if err != nil {
		return err
	}
	defer stmt.Close()

	for _, doc := range docs {
		jsonData, err := json.Marshal(doc.Data)
		if err != nil {
			return err
		}
		now := time.Now()
		if _, err := stmt.ExecContext(ctx, doc.ID, jsonData, now, now); err != nil {
			return err
		}
	}
	return tx.Commit()
}

// UpdateMany applies update to every document whose id is in ids, inside a
// single transaction, and sets updated_at to the current time.
//
// Every field in update.Set is written with json_set and every field in
// update.Unset is removed with json_remove. All operations are nested into a
// single expression, e.g. json_remove(json_set(data, ?, ?), ?), so one UPDATE
// applies the whole update to each row. It returns nil without touching the
// database when there is nothing to change or no ids are given.
//
// NOTE(review): json_set/json_remove exist in SQLite and MySQL but not in
// PostgreSQL, which needs its own override. Values are bound as the JSON text
// produced by toJSONString; with a plain ? argument SQLite presumably stores
// that text as a JSON string rather than a typed JSON value (e.g. 5 becomes
// "5"). Confirm whether json(?) was intended.
func (a *BaseAdapter) UpdateMany(ctx context.Context, collection string, ids []string, update types.Update) error {
	// Each wrapper places the existing expression (and its placeholders)
	// before its own placeholders, so appending args in the same order keeps
	// args aligned with the placeholders left to right. JSON paths are bound
	// as parameters so field names cannot alter the SQL text.
	expr := "data"
	var exprArgs []interface{}
	for field, value := range update.Set {
		expr = fmt.Sprintf("json_set(%s, ?, ?)", expr)
		exprArgs = append(exprArgs, fmt.Sprintf("$.%s", field), toJSONString(value))
	}
	for field := range update.Unset {
		expr = fmt.Sprintf("json_remove(%s, ?)", expr)
		exprArgs = append(exprArgs, fmt.Sprintf("$.%s", field))
	}
	if len(exprArgs) == 0 || len(ids) == 0 {
		return nil
	}

	query := fmt.Sprintf("UPDATE %s SET data = %s, updated_at = ? WHERE id = ?", collection, expr)

	tx, err := a.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a harmless no-op.
	defer tx.Rollback()

	stmt, err := tx.PrepareContext(ctx, query)
	if err != nil {
		return err
	}
	defer stmt.Close()

	// Argument order must match placeholder order: expression args, then
	// updated_at, then id. Only the trailing id slot changes per row.
	args := make([]interface{}, len(exprArgs)+2)
	copy(args, exprArgs)
	args[len(exprArgs)] = time.Now()
	for _, id := range ids {
		args[len(args)-1] = id
		if _, err := stmt.ExecContext(ctx, args...); err != nil {
			return err
		}
	}
	return tx.Commit()
}

// DeleteMany deletes every document whose id is in ids with a single
// DELETE ... WHERE id IN (...) statement. An empty ids slice is a no-op.
//
// NOTE(review): one placeholder is bound per id, so very large slices may
// exceed the driver's bound-parameter limit (e.g. SQLite's
// SQLITE_MAX_VARIABLE_NUMBER); batch on the caller side if needed.
func (a *BaseAdapter) DeleteMany(ctx context.Context, collection string, ids []string) error {
	if len(ids) == 0 {
		return nil
	}

	args := make([]interface{}, len(ids))
	for i, id := range ids {
		args[i] = id
	}
	// "?, ?, ..., ?" with exactly len(ids) placeholders.
	placeholders := "?" + strings.Repeat(", ?", len(ids)-1)

	query := fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", collection, placeholders)
	_, err := a.db.ExecContext(ctx, query, args...)
	return err
}

// FindAll returns every document in collection, decoding the JSON data
// column into Document.Data. It returns a nil slice when the table is empty.
func (a *BaseAdapter) FindAll(ctx context.Context, collection string) ([]types.Document, error) {
	query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s", collection)
	rows, err := a.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var docs []types.Document
	for rows.Next() {
		var doc types.Document
		var jsonData []byte
		if err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt); err != nil {
			return nil, err
		}
		if err := json.Unmarshal(jsonData, &doc.Data); err != nil {
			return nil, err
		}
		docs = append(docs, doc)
	}
	return docs, rows.Err()
}

// BeginTx starts a database transaction with default options.
func (a *BaseAdapter) BeginTx(ctx context.Context) (Transaction, error) {
	tx, err := a.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	return &baseTransaction{tx: tx}, nil
}

// baseTransaction adapts *sql.Tx to the Transaction interface.
type baseTransaction struct {
	tx *sql.Tx
}

// Commit commits the underlying transaction.
func (t *baseTransaction) Commit() error {
	return t.tx.Commit()
}

// Rollback aborts the underlying transaction.
func (t *baseTransaction) Rollback() error {
	return t.tx.Rollback()
}

// ListCollections lists all collections (tables). The system catalog differs
// per database, so concrete adapters must implement this; the base version
// always returns ErrNotImplemented.
func (a *BaseAdapter) ListCollections(ctx context.Context) ([]string, error) {
	return nil, ErrNotImplemented
}

// toJSONString returns the JSON encoding of v, or "null" for a nil v.
//
// NOTE(review): a json.Marshal error is ignored and yields "" (the string of a
// nil byte slice); values that cannot be marshaled are therefore written as
// an empty string rather than reported.
func toJSONString(v interface{}) string {
	if v == nil {
		return "null"
	}
	data, _ := json.Marshal(v)
	return string(data)
}