测试
我们坚信质量保证 (QA) 流程的重要性,我们希望确保添加到 OpenBB 平台的所有扩展都具有高质量。
为了确保这一点,我们提供了一套 QA 工具供您测试您的工作。
主要地,我们有半自动化创建单元测试和集成测试的工具。
QA 工具仍在开发中,我们正在不断改进它们。
单元测试
每个 Fetcher
都配备了一个 test
方法,该方法将确保它被正确实现,返回预期的数据,所有类型都正确,并且数据有效。
要为您的 Fetchers 创建单元测试,您可以运行以下命令:
python openbb_platform/providers/tests/utils/unit_tests_generator.py
从仓库根目录运行此文件,并且必须存在
tests
文件夹才能生成测试。
自动单元测试生成将为给定提供商中可用的所有获取器添加单元测试。
要记录单元测试,您可以运行以下命令:
pytest <path_to_the_unit_test_file> --record=all
备注
有时需要手动干预。例如,调整顶级导入或为给定获取器添加特定参数。
集成测试
集成测试比单元测试更复杂,因为我们想要测试 Python 接口和 API 接口。为此,我们有两个脚本来帮助您生成集成测试。
Python 接口集成测试
python openbb_platform/openbb/package/integration_test_generator.py
此脚本将为 Python 接口生成集成测试。
API 接口集成测试
python openbb_platform/openbb/package/integration_test_api_generator.py
此脚本将为 REST API 接口生成集成测试。
测试最佳实践
1. 测试结构
import pytest
from unittest.mock import Mock, patch
from openbb_core.provider.abstract.fetcher import Fetcher
class TestMyDataFetcher:
"""测试 MyDataFetcher 类"""
def setup_method(self):
"""每个测试方法前的设置"""
self.fetcher = MyDataFetcher()
def teardown_method(self):
"""每个测试方法后的清理"""
pass
def test_transform_query_success(self):
"""测试查询转换成功情况"""
query_params = {"symbol": "AAPL", "period": "1d"}
result = self.fetcher.transform_query(query_params)
assert isinstance(result, dict)
assert "symbol" in result
assert result["symbol"] == "AAPL"
def test_transform_query_invalid_symbol(self):
"""测试无效股票代码的错误处理"""
query_params = {"symbol": "", "period": "1d"}
with pytest.raises(ValueError, match="股票代码不能为空"):
self.fetcher.transform_query(query_params)
@patch('requests.get')
def test_extract_data_success(self, mock_get):
"""测试数据提取成功情况"""
# 模拟 API 响应
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [{"symbol": "AAPL", "price": 150.0}]
}
mock_get.return_value = mock_response
query = {"symbol": "AAPL"}
credentials = {"api_key": "test_key"}
result = self.fetcher.extract_data(query, credentials)
assert isinstance(result, dict)
assert "data" in result
mock_get.assert_called_once()
@patch('requests.get')
def test_extract_data_api_error(self, mock_get):
"""测试 API 错误处理"""
mock_response = Mock()
mock_response.status_code = 404
mock_get.return_value = mock_response
query = {"symbol": "INVALID"}
credentials = {"api_key": "test_key"}
with pytest.raises(Exception, match="API 请求失败"):
self.fetcher.extract_data(query, credentials)
def test_transform_data_success(self):
"""测试数据转换成功情况"""
raw_data = {
"data": [
{"symbol": "AAPL", "price": 150.0, "date": "2023-01-01"},
{"symbol": "AAPL", "price": 151.0, "date": "2023-01-02"}
]
}
query = {"symbol": "AAPL"}
result = self.fetcher.transform_data(query, raw_data)
assert isinstance(result, list)
assert len(result) == 2
assert all("symbol" in item for item in result)
def test_transform_data_empty_data(self):
"""测试空数据处理"""
raw_data = {"data": []}
query = {"symbol": "AAPL"}
result = self.fetcher.transform_data(query, raw_data)
assert isinstance(result, list)
assert len(result) == 0
2. 参数化测试
@pytest.mark.parametrize("symbol,expected", [
("AAPL", True),
("MSFT", True),
("INVALID", False),
("", False),
])
def test_validate_symbol(symbol, expected):
"""参数化测试股票代码验证"""
result = validate_symbol(symbol)
assert result == expected
@pytest.mark.parametrize("period,interval,expected_calls", [
("1d", "1m", 1440), # 1天,1分钟间隔
("1d", "5m", 288), # 1天,5分钟间隔
("1w", "1h", 168), # 1周,1小时间隔
])
def test_data_points_calculation(period, interval, expected_calls):
"""参数化测试数据点计算"""
result = calculate_data_points(period, interval)
assert result == expected_calls
3. 异步测试
import asyncio
import pytest
@pytest.mark.asyncio
async def test_async_data_fetch():
"""测试异步数据获取"""
fetcher = AsyncDataFetcher()
result = await fetcher.fetch_data_async("AAPL")
assert result is not None
assert "symbol" in result
assert result["symbol"] == "AAPL"
@pytest.mark.asyncio
async def test_concurrent_requests():
"""测试并发请求"""
fetcher = AsyncDataFetcher()
symbols = ["AAPL", "MSFT", "GOOGL"]
tasks = [fetcher.fetch_data_async(symbol) for symbol in symbols]
results = await asyncio.gather(*tasks)
assert len(results) == 3
assert all(result is not None for result in results)
4. 集成测试示例
@pytest.mark.integration
class TestOpenBBIntegration:
"""OpenBB 平台集成测试"""
def test_equity_price_historical(self):
"""测试股票历史价格获取"""
from openbb import obb
result = obb.equity.price.historical(
symbol="AAPL",
provider="yfinance"
)
assert result is not None
df = result.to_dataframe()
assert not df.empty
assert "close" in df.columns
assert "volume" in df.columns
def test_economy_indicators(self):
"""测试经济指标获取"""
from openbb import obb
result = obb.economy.indicators(
symbol="CPI",
country="US",
provider="econdb"
)
assert result is not None
df = result.to_dataframe()
assert not df.empty
assert "value" in df.columns
@pytest.mark.slow
def test_large_data_request(self):
"""测试大数据量请求"""
from openbb import obb
result = obb.equity.price.historical(
symbol="SPY",
start_date="2020-01-01",
end_date="2023-12-31",
provider="yfinance"
)
df = result.to_dataframe()
assert len(df) > 1000 # 应该有足够的数据点
测试配置
pytest.ini 配置
[tool:pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
-v
--tb=short
--strict-markers
--disable-warnings
markers =
unit: 单元测试
integration: 集成测试
slow: 慢速测试
api: API 测试
requires_credentials: 需要凭据的测试
conftest.py 配置
import pytest
from unittest.mock import Mock
@pytest.fixture
def mock_api_response():
"""模拟 API 响应"""
response = Mock()
response.status_code = 200
response.json.return_value = {
"data": [
{"symbol": "AAPL", "price": 150.0, "date": "2023-01-01"}
]
}
return response
@pytest.fixture
def sample_credentials():
"""示例凭据"""
return {
"api_key": "test_api_key",
"secret": "test_secret"
}
@pytest.fixture
def sample_query_params():
"""示例查询参数"""
return {
"symbol": "AAPL",
"period": "1d",
"interval": "1m"
}
@pytest.fixture(scope="session")
def test_database():
"""测试数据库设置"""
# 设置测试数据库
db = create_test_database()
yield db
# 清理测试数据库
cleanup_test_database(db)
测试覆盖率
生成覆盖率报告
# 运行测试并生成覆盖率报告
pytest --cov=openbb_platform --cov-report=html --cov-report=term
# 只显示缺失覆盖的行
pytest --cov=openbb_platform --cov-report=term-missing
# 设置最低覆盖率要求
pytest --cov=openbb_platform --cov-fail-under=80
覆盖率配置
[tool:pytest]
addopts = --cov=openbb_platform --cov-report=term-missing --cov-fail-under=80
[coverage:run]
source = openbb_platform
omit =
*/tests/*
*/test_*
*/__init__.py
*/conftest.py
[coverage:report]
exclude_lines =
pragma: no cover
def __repr__
raise AssertionError
raise NotImplementedError
持续集成测试
GitHub Actions 配置
name: Tests
on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main ]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9, 3.10, 3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
- name: Run unit tests
run: |
pytest tests/unit --cov=openbb_platform
- name: Run integration tests
run: |
pytest tests/integration -m "not slow"
env:
API_KEY: ${{ secrets.API_KEY }}
- name: Upload coverage reports
uses: codecov/codecov-action@v3
性能测试
import time
import pytest
from memory_profiler import profile
def test_import_time():
"""测试导入时间"""
start_time = time.time()
import openbb
import_time = time.time() - start_time
# 导入时间应该少于5秒
assert import_time < 5.0, f"导入时间过长: {import_time:.2f}秒"
@profile
def test_memory_usage():
"""测试内存使用"""
from openbb import obb
# 执行一些操作
result = obb.equity.price.historical("AAPL", provider="yfinance")
df = result.to_dataframe()
# 内存使用应该在合理范围内
assert len(df) > 0
@pytest.mark.benchmark
def test_data_fetch_performance(benchmark):
"""基准测试数据获取性能"""
from openbb import obb
def fetch_data():
return obb.equity.price.quote("AAPL", provider="yfinance")
result = benchmark(fetch_data)
assert result is not None
测试最佳实践总结
- 测试命名:使用描述性的测试名称
- 测试隔离:每个测试应该独立运行
- 模拟外部依赖:使用 mock 避免真实 API 调用
- 参数化测试:测试多种输入情况
- 异常测试:测试错误处理路径
- 性能测试:监控关键操作的性能
- 覆盖率监控:保持高测试覆盖率
- 持续集成:自动化测试流程
# 完整的测试示例
class TestCompleteExample:
"""完整的测试示例"""
@pytest.fixture(autouse=True)
def setup(self):
"""自动设置"""
self.fetcher = MyDataFetcher()
def test_happy_path(self):
"""测试正常流程"""
# 准备
query = {"symbol": "AAPL"}
# 执行
result = self.fetcher.process(query)
# 验证
assert result is not None
assert "data" in result
def test_error_handling(self):
"""测试错误处理"""
with pytest.raises(ValueError):
self.fetcher.process({"symbol": ""})
@pytest.mark.parametrize("symbol", ["AAPL", "MSFT", "GOOGL"])
def test_multiple_symbols(self, symbol):
"""测试多个股票代码"""
result = self.fetcher.process({"symbol": symbol})
assert result["symbol"] == symbol
@patch('requests.get')
def test_with_mock(self, mock_get):
"""使用模拟的测试"""
mock_get.return_value.json.return_value = {"data": []}
result = self.fetcher.fetch_external_data()
assert result == {"data": []}