Open Source, Open Future!
  menu
120 文章
ღゝ◡╹)ノ❤️

量化系列番外篇:给量化系统加个 AI 智能体助手

一、写在前面

做量化研究最烦的不是写策略,是查数据。

上次写完量化系列,数据都存在 Parquet 文件里。想查点东西得这样:

# 查茅台最近20天的收盘价
df = pd.read_parquet('data/stock_data.parquet')
maotai = df[df['symbol'] == '600519.SH'].tail(20)

# 查 RSI 超过70的股票
features = pd.read_parquet('data/features.parquet')
high_rsi = features[features['rsi'] > 70]

每次都要写代码,挺麻烦的。能不能直接问 AI呢?

比如:

  • 「茅台最近20天的收盘价是多少?」
  • 「RSI 超过70的股票有哪些?」
  • 「数据集里一共有多少只股票?」

这个想法挺有意思。正好最近在研究 LLM 和 AI Agent,就试着做了一个。

本文记录一下实现过程,涉及:

  • AI Agent(LLM 生成代码 → 执行 → 返回结果)
  • 通义千问 API
  • 代码安全执行
  • Parquet 数据查询

二、技术选型

最终选了:

  • LLM: 通义千问 qwen-max(国内访问快,中文理解好)
  • API 调用: urllib.request(标准库,无需额外依赖)
  • 数据处理: pandas + Parquet
  • 代码执行: exec() + 沙箱环境

三、核心实现

3.1 整体架构

用户问题 → Schema Provider → Prompt 构建 → 通义千问 API → 生成代码 → 安全执行 → 返回结果

3.2 Schema 自动获取

让 LLM 知道数据结构,真正自动读取 Parquet 文件的 Schema:

class SchemaProvider:
    """自动获取数据文件的 Schema"""

    def __init__(self, data_dir='data'):
        self.data_dir = data_dir
        self.schema_cache = {}

    def get_schema(self) -> str:
        if 'schema' not in self.schema_cache:
            self.schema_cache['schema'] = self._load_schema()
        return self.schema_cache['schema']

    def _load_schema(self) -> str:
        """自动读取 Parquet 文件的 Schema"""
        schema = []

        # 列出目录下的 parquet 文件
        files = os.listdir(self.data_dir)
        parquet_files = [f for f in files if f.endswith('.parquet')]

        schema.append(f"数据目录: {self.data_dir}")
        schema.append(f"找到 {len(parquet_files)} 个 parquet 文件:")
        schema.append("")

        # 遍历每个 parquet 文件,读取真实的 Schema
        for f in sorted(parquet_files):
            filepath = os.path.join(self.data_dir, f)
            size = os.path.getsize(filepath)

            schema.append(f"{f} ({size / 1024 / 1024:.1f} MB):")

            # 读取文件来获取列信息
            df = pd.read_parquet(filepath)
            row_count = len(df)

            schema.append(f"  行数: {row_count:,}")
            schema.append(f"  字段 ({len(df.columns)} 个):")

            # 列出所有字段及其类型
            for col in df.columns:
                dtype = df[col].dtype
                schema.append(f"    - {col} ({dtype})")

            # 添加说明
            if 'stock_data' in f:
                schema.append("  说明: 股票日线行情数据,股票代码格式如 600519.SH(茅台)")
            elif 'features' in f:
                schema.append("  说明: 技术指标和因子数据")

            schema.append("")

        return "\n".join(schema)

生成的 Schema 示例(真实数据):

数据目录: /home/devuser/data/yc/量化/data
找到 3 个 parquet 文件:

features.parquet (426.1 MB):
  行数: 1,537,663
  字段 (53 个):
    - date (datetime64[us])
    - symbol (str)
    - open (float64)
    - close (float64)
    - return_1d (float64)
    - return_5d (float64)
    - ma_5 (float64)
    - ma_20 (float64)
    - rsi (float64)
    - macd (float64)
    ... (共53个字段)
  说明: 技术指标和因子数据

stock_data.parquet (38.5 MB):
  行数: 1,589,404
  字段 (10 个):
    - date (datetime64[us])
    - symbol (str)
    - open (float64)
    - high (float64)
    - low (float64)
    - close (float64)
    - volume (float64)
    - amount (float64)
    - pct_chg (float64)
    - change (float64)
  说明: 股票日线行情数据,股票代码格式如 600519.SH(茅台)

关键点:

  • 自动读取:使用pd.read_parquet() 读取文件,获取真实的列名和类型
  • 缓存机制:避免重复读取大文件
  • 完整信息:包含文件大小、行数、字段数量、每个字段的类型

3.3 通义千问 API 调用

使用标准库 urllib.request,无需安装 SDK:

class QwenClient:
    """通义千问 API 客户端"""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.api_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

    def generate(self, prompt: str, model: str = "qwen-max") -> str:
        """调用通义千问 API 生成文本"""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        data = {
            "model": model,
            "input": {
                "messages": [{"role": "user", "content": prompt}]
            },
            "parameters": {
                "result_format": "message"
            }
        }

        req = urllib.request.Request(
            self.api_url,
            data=json.dumps(data).encode('utf-8'),
            headers=headers,
            method='POST'
        )

        with urllib.request.urlopen(req, timeout=60) as response:
            result = json.loads(response.read().decode('utf-8'))
            if result.get('output') and result['output'].get('choices'):
                return result['output']['choices'][0]['message']['content']
            else:
                return f"# ERROR: API 返回格式异常"

3.4 Prompt 工程

关键是告诉 LLM:pandas 已经预导入,不要写 import 语句:

CODE_GEN_TEMPLATE = """
你是一个专业的数据分析师。根据用户的自然语言问题和数据Schema,生成准确的 pandas 代码。

数据Schema:
{schema}

用户问题: {question}

要求:
1. 只生成查询代码,不要生成修改数据的代码(不要 to_parquet、to_csv 等)
2. 只返回 Python 代码本身,不要包含任何解释或markdown格式
3. pandas 已经导入为 pd,numpy 已经导入为 np,不要写 import 语句
4. 使用 pd.read_parquet() 读取文件,路径格式: 'data/stock_data.parquet'
5. 股票代码格式是 '600519.SH'(茅台)、'000001.SZ'(平安银行)
6. 代码最后必须用 print() 输出结果
7. 不要使用 ```python 这样的 markdown 格式,直接返回纯代码

示例:
问题: 茅台最近20天的收盘价
代码:
df = pd.read_parquet('data/stock_data.parquet')
result = df[df['symbol'] == '600519.SH'].tail(20)[['date', 'close']]
print(result)

Python代码:
"""

3.5 安全代码执行

三层防护:关键字检测 + 沙箱环境 + 工作目录控制

class PythonExecutor:
    """安全执行 Python 代码"""

    DANGEROUS_KEYWORDS = [
        'import os', 'import sys', 'import subprocess',
        'exec(', 'eval(',
        'open(', 'write', 'delete', 'remove', 'rmdir',
        'to_parquet', 'to_csv', 'to_excel'
    ]

    def execute(self, code: str, data_dir: str = '.'):
        """执行代码并返回结果"""
        # 1. 安全检查
        self._validate_code(code)

        # 2. 切换到正确的工作目录
        original_dir = os.getcwd()
        work_dir = os.path.dirname(data_dir)
        os.chdir(work_dir)

        try:
            output_buffer = io.StringIO()

            with redirect_stdout(output_buffer):
                import numpy as np
                import datetime

                # 预导入 pandas、numpy、datetime,限制可用的内置函数
                safe_globals = {
                    '__builtins__': {
                        'print': print, 'len': len, 'range': range,
                        'str': str, 'int': int, 'float': float,
                        'list': list, 'dict': dict, 'tuple': tuple,
                        'sum': sum, 'min': min, 'max': max,
                        'sorted': sorted, 'enumerate': enumerate,
                        # ... 其他安全的内置函数
                    },
                    'pd': pd,
                    'np': np,
                    'datetime': datetime,
                }
                exec(code, safe_globals)
        finally:
            os.chdir(original_dir)

        return {
            "success": True,
            "output": output_buffer.getvalue()
        }

    def _validate_code(self, code: str):
        """安全检查"""
        code_lower = code.lower()
        for keyword in self.DANGEROUS_KEYWORDS:
            if keyword.lower() in code_lower:
                raise SecurityError(f"不允许的操作: 包含危险关键字 {keyword}")

3.6 AI Agent 核心逻辑

class DataQueryAgent:
    """数据查询 AI Agent"""

    def __init__(self, data_dir='data', api_key=None):
        self.data_dir = data_dir
        self.schema_provider = SchemaProvider(data_dir)
        self.python_executor = PythonExecutor()
        self.qwen_client = QwenClient(api_key) if api_key else None

    def process_question(self, question: str):
        """处理用户问题"""
        # 1. 获取 Schema
        schema = self.schema_provider.get_schema()

        # 2. 构建 Prompt
        prompt = build_prompt(schema, "", question)

        # 3. 调用 LLM 生成代码
        response = self.qwen_client.generate(prompt)
        code = self._extract_code(response)

        # 4. 执行代码
        result = self.python_executor.execute(code, self.data_dir)

        return result

四、运行效果

4.1 测试环境

  • 数据规模:stock_data.parquet (38.5 MB, 158万条), features.parquet (426.1 MB, 39个特征)
  • 股票数量:797只
  • 时间范围:2024年全年数据

4.2 测试案例

测试1:茅台最近20天收盘价

用户问:「茅台最近20天的收盘价?」

通义千问生成的代码:

df = pd.read_parquet('data/stock_data.parquet')
result = df[df['symbol'] == '600519.SH'].tail(20)[['date', 'close']]
print(result)

实际返回结果:

             date    close
956550 2024-12-04  1496.60
956551 2024-12-05  1487.74
956552 2024-12-06  1497.59
...
956568 2024-12-30  1525.00
956569 2024-12-31  1524.00

✅ 成功:准确生成代码,成功查询到2024年12月的20天收盘价数据。

测试2:统计股票数量

用户问:「数据集里一共有多少只股票?」

通义千问生成的代码:

df = pd.read_parquet('data/stock_data.parquet')
unique_stocks = df['symbol'].nunique()
print(unique_stocks)

实际返回结果:

797

✅ 成功:正确使用 nunique() 方法统计唯一股票数量。

六、总结

给量化系统加了个 AI 智能体助手,核心功能都实现了:

已实现:

  • ✅ AI Agent(LLM 生成代码 → 执行 → 返回结果)
  • ✅ Schema 自动获取(扫描数据目录)
  • ✅ 通义千问 API 集成(零依赖,使用标准库)
  • ✅ 安全代码执行(关键字检测、沙箱环境、工作目录控制)
  • ✅ 真实数据测试(158万条行情数据 + 39个技术指标)

对量化研究来说,这个工具确实提升了效率。以前查个数据要:

  1. 打开 Jupyter Notebook
  2. 写 pandas 代码
  3. 运行、调试
  4. 看结果

现在直接问就行,2-3秒出结果。

运行方式:

# 设置 API Key
export DASHSCOPE_API_KEY=your_api_key

# 运行
python3 query_agent.py