Open Source, Open Future!
  menu
120 文章
ღゝ◡╹)ノ❤️

当我让AI接管数据库:一个NL2SQL智能体的诞生记

一、写在前面

最近想做一个"数据库智能体"——就是用自然语言问数据库问题,AI 自动生成 SQL 并返回结果。

看起来挺简单的,但仔细想想,要做好还真不容易:

  • 怎么把自然语言转成准确的 SQL?
  • 怎么保证 SQL 安全(防注入、防误删数据)?
  • 怎么支持多轮对话("他们的平均年龄是多少"这种上下文问题)?
  • 怎么接入 MCP 协议让其他系统调用?

二、问题定义

2.1 具体场景

假设有这样一个数据库:

-- 用户表
users (id, name, email, age)

-- 订单表
orders (id, user_id, product_name, amount, order_date)

-- 商品表
products (id, name, price, stock, category)

用户想问:

  • "查询所有用户"
  • "张三买了什么?"
  • "哪些用户买过外设类商品?"
  • "每个用户的订单总金额是多少?"

2.2 想法1:写死 SQL 模板

最简单的想法是预定义一堆 SQL 模板:

templates = {
    "查询所有用户": "SELECT * FROM users",
    "查询订单": "SELECT * FROM orders WHERE user_id = ?",
    ...
}

问题:

  • 模板数量爆炸(每种问法都要写一个)
  • 无法处理复杂查询(跨表、聚合)
  • 不支持自然语言变化("所有用户" vs "全部用户")

2.3 想法2:让 LLM 直接生成 SQL

直接把问题扔给 GPT:

prompt = f"把这个问题转成 SQL: {question}"
sql = llm.generate(prompt)

问题:

  • LLM 不知道数据库结构(表名、列名)
  • 容易生成错误的 SQL
  • 没有安全控制

2.4 正确做法:NL2SQL + Schema + 安全控制

把问题建模成这样:

用户问题 + 数据库 Schema → LLM → SQL → 安全校验 → 执行 → 结果

关键点:

  1. Schema 提供:告诉 LLM 数据库有哪些表、哪些字段
  2. Prompt 工程:设计好的 Prompt 让 LLM 生成准确的 SQL
  3. 安全校验:防止 SQL 注入、防止误删数据
  4. 对话管理:维护上下文,支持多轮对话

三、技术栈

组件选型理由
框架Spring Boot 3.2成熟稳定
LLM 编排LangChain4j 0.35API 简洁,Java 生态
大模型通义千问 qwen-max国内访问快,效果好
数据库MySQL 8.0通用
连接池HikariCP高性能
SQL 解析JSqlParser安全校验用

四、核心实现

4.1 架构设计

整体分层:

HTTP API (Controller)
    ↓
对话管理 (ConversationService)
    ↓
NL2SQL Agent (LangChain4j + 通义千问)
    ↓
SQL 执行器 (DalExecutor)
    ↓
数据库 (MySQL)

4.2 Schema 自动获取

第一步是让 LLM 知道数据库结构。用 JDBC 的 DatabaseMetaData 自动读取:

@Component
public class SchemaProvider {

    private final DataSource dataSource;
    private final ConcurrentHashMap<String, String> schemaCache = new ConcurrentHashMap<>();

    public String getSchema() {
        return schemaCache.computeIfAbsent("schema", k -> loadSchema());
    }

    private String loadSchema() {
        StringBuilder schema = new StringBuilder();

        try (Connection conn = dataSource.getConnection()) {
            DatabaseMetaData metaData = conn.getMetaData();
            ResultSet tables = metaData.getTables(catalog, null, "%", new String[]{"TABLE"});

            while (tables.next()) {
                String tableName = tables.getString("TABLE_NAME");
                schema.append("表名: ").append(tableName).append("\n");

                ResultSet columns = metaData.getColumns(catalog, null, tableName, "%");
                while (columns.next()) {
                    String columnName = columns.getString("COLUMN_NAME");
                    String columnType = columns.getString("TYPE_NAME");
                    schema.append("  - ").append(columnName)
                          .append(" (").append(columnType).append(")\n");
                }
            }
        }

        return schema.toString();
    }
}

生成的 Schema 长这样:

表名: users
  - id (INT)
  - name (VARCHAR)
  - email (VARCHAR)
  - age (INT)

表名: orders
  - id (INT)
  - user_id (INT)
  - product_name (VARCHAR)
  - amount (DECIMAL)
  - order_date (DATE)

关键点:

  • ConcurrentHashMap 缓存,避免每次都查数据库
  • 格式化成 LLM 容易理解的文本

4.3 Prompt 工程

Prompt 是核心,直接影响 SQL 生成质量:

public static final String NL2SQL_TEMPLATE = """
    你是一个专业的SQL专家。根据用户的自然语言问题和数据库Schema,生成准确的SQL查询。

    数据库Schema:
    {schema}

    历史对话:
    {history}

    用户问题: {question}

    要求:
    1. 只生成SELECT语句,不要生成INSERT/UPDATE/DELETE/DROP等修改数据的语句
    2. 只返回SQL语句本身,不要包含任何解释或markdown格式
    3. SQL语句要准确、高效
    4. 如果问题无法用SQL回答,返回: ERROR: 无法生成SQL

    SQL:
    """;

关键点:

  • 明确告诉 LLM 只生成 SELECT(安全第一)
  • 提供完整的 Schema 信息
  • 包含历史对话(支持多轮对话)
  • 要求只返回 SQL,不要废话

4.4 NL2SQL Agent

核心逻辑很简洁:

@Component
public class NL2SQLAgent {

    private final ChatLanguageModel chatModel;
    private final SchemaProvider schemaProvider;
    private final PromptTemplate promptTemplate;
    private final DalExecutor dalExecutor;

    public QueryResult processQuestion(String question, String conversationHistory) {
        try {
            // 1. 获取 Schema
            String schema = schemaProvider.getSchema();

            // 2. 构建 Prompt
            String prompt = promptTemplate.buildPrompt(schema, conversationHistory, question);

            // 3. 调用 LLM 生成 SQL
            String response = chatModel.generate(prompt);
            String sql = extractSql(response);

            // 4. 执行 SQL
            return dalExecutor.executeQuery(sql);

        } catch (Exception e) {
            return QueryResult.builder()
                    .success(false)
                    .errorMessage("处理失败: " + e.getMessage())
                    .build();
        }
    }

    private String extractSql(String response) {
        // 清理 LLM 返回的 markdown 格式
        String cleaned = response.trim();
        if (cleaned.startsWith("```sql")) {
            cleaned = cleaned.substring(6);
        }
        if (cleaned.endsWith("```")) {
            cleaned = cleaned.substring(0, cleaned.length() - 3);
        }
        return cleaned.trim();
    }
}

4.5 SQL 安全校验

这是最重要的部分——防止 SQL 注入和误操作:

@Component
public class SqlValidator {

    @Value("${security.allowed-operations}")
    private List<String> allowedOperations;

    public void validate(String sql) {
        // 1. 检查危险关键字
        String[] dangerousKeywords = {"DROP", "TRUNCATE", "ALTER", "DELETE", "UPDATE"};
        String upperSql = sql.trim().toUpperCase();

        for (String keyword : dangerousKeywords) {
            if (upperSql.contains(keyword)) {
                throw new SecurityException("不允许的SQL操作: 包含危险关键字 " + keyword);
            }
        }

        // 2. 解析 SQL 类型
        try {
            Statement stmt = CCJSqlParserUtil.parse(sql);

            if (stmt instanceof Delete) {
                throw new SecurityException("不允许的SQL操作: DELETE");
            }
            if (stmt instanceof Drop) {
                throw new SecurityException("不允许的SQL操作: DROP");
            }

        } catch (SecurityException e) {
            throw e;
        } catch (Exception e) {
            // 解析失败,只允许 SELECT
            if (!upperSql.startsWith("SELECT")) {
                throw new SecurityException("SQL解析失败,只允许SELECT语句");
            }
        }
    }
}

两层防护:

  1. 关键字检测:直接拦截 DROP、DELETE 等危险操作
  2. SQL 解析:用 JSqlParser 解析 SQL 类型,确保只有 SELECT

4.6 对话管理

支持多轮对话的关键是维护上下文:

@Service
public class ConversationService {

    private final ConcurrentHashMap<String, Session> sessions = new ConcurrentHashMap<>();

    @Value("${conversation.max-history}")
    private int maxHistory;  // 保留最近 5 轮

    public String getConversationHistory(String sessionId) {
        Session session = sessions.get(sessionId);
        if (session == null) {
            return "";
        }

        List<Message> recentMessages = session.getRecentMessages(maxHistory);
        StringBuilder history = new StringBuilder();

        for (Message msg : recentMessages) {
            history.append(msg.getRole()).append(": ").append(msg.getContent());
            if (msg.getSql() != null) {
                history.append(" [SQL: ").append(msg.getSql()).append("]");
            }
            history.append("\n");
        }

        return history.toString();
    }
}

这样就能处理这种对话:

用户: 查询所有用户
AI: SELECT * FROM users

用户: 他们的平均年龄是多少?
AI: SELECT AVG(age) FROM users  ← 知道"他们"指的是 users 表

五、运行效果

5.1 单表查询

用户问:"查询所有用户"

生成的 SQL:

SELECT * FROM users

返回结果:

{
  "sessionId": "abc-123",
  "sql": "SELECT * FROM users",
  "success": true,
  "columns": ["id", "name", "email", "age"],
  "rows": [
    {"id": 1, "name": "张三", "email": "zhangsan@example.com", "age": 25},
    {"id": 2, "name": "李四", "email": "lisi@example.com", "age": 30},
    {"id": 3, "name": "王五", "email": "wangwu@example.com", "age": 28}
  ],
  "rowCount": 3,
  "executionTimeMs": 45
}

5.2 跨表查询

用户问:"查询每个用户的订单总金额,按金额从高到低排序"

生成的 SQL:

SELECT u.name, SUM(o.amount) AS total_amount
FROM users u
JOIN orders o ON u.id = o.user_id
GROUP BY u.id, u.name
ORDER BY total_amount DESC

返回结果:

张三  6098.00
王五  1299.00
李四   299.00

LLM 自动识别出需要 JOIN 两张表,还加了 GROUP BY 和 ORDER BY。

5.3 多轮对话

第一轮:

用户: 查询所有用户
AI: SELECT * FROM users

第二轮:

用户: 他们的平均年龄是多少?
AI: SELECT AVG(age) FROM users

LLM 通过历史对话知道"他们"指的是 users 表的用户。

5.4 安全拦截

用户问:"删除所有用户"

返回:

{
  "success": false,
  "error": "ERROR: 无法生成SQL"
}

成功拦截危险操作。

六、MCP 协议支持

6.1 什么是 MCP

MCP (Model Context Protocol) 是一个标准化的工具协议,让 AI 能调用外部工具。

简单说就是定义一套接口:

  • GET /mcp/tools - 获取可用工具列表
  • POST /mcp/tools/call - 调用工具

6.2 实现 MCP Server

定义两个工具:

@Component
public class McpServer {

    public List<McpToolDefinition> getTools() {
        List<McpToolDefinition> tools = new ArrayList<>();

        // 工具1:查询数据库
        tools.add(McpToolDefinition.builder()
                .name("query_database")
                .description("使用自然语言查询数据库")
                .inputSchema(Map.of(
                    "type", "object",
                    "properties", Map.of(
                        "question", Map.of(
                            "type", "string",
                            "description", "用户的自然语言问题"
                        )
                    ),
                    "required", List.of("question")
                ))
                .build());

        // 工具2:获取 Schema
        tools.add(McpToolDefinition.builder()
                .name("get_schema")
                .description("获取数据库Schema信息")
                .inputSchema(Map.of("type", "object"))
                .build());

        return tools;
    }

    public Object executeTool(String toolName, Map<String, Object> arguments) {
        return switch (toolName) {
            case "query_database" -> executeQueryDatabase(arguments);
            case "get_schema" -> executeGetSchema();
            default -> throw new IllegalArgumentException("未知的工具: " + toolName);
        };
    }
}

6.3 测试 MCP 接口

获取工具列表:

curl http://localhost:8080/mcp/tools

返回:

[
  {
    "name": "query_database",
    "description": "使用自然语言查询数据库",
    "inputSchema": {...}
  },
  {
    "name": "get_schema",
    "description": "获取数据库Schema信息",
    "inputSchema": {...}
  }
]

调用工具:

curl -X POST http://localhost:8080/mcp/tools/call \
  -H "Content-Type: application/json" \
  -d '{
    "tool": "query_database",
    "arguments": {
      "question": "有多少个用户?"
    }
  }'

返回:

{
  "success": true,
  "content": {
    "sql": "SELECT COUNT(*) FROM users",
    "rows": [{"COUNT(*)": 3}],
    "rowCount": 1
  }
}