#!/usr/bin/env python3
"""
DMS Compliance Test Tool - FastAPI API server

Provides auto-generated interactive API documentation.
"""

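# Quick start (illustrative; assumes this module is saved as fastapi_server.py,
# matching the uvicorn.run() target at the bottom of this file):
#
#   python fastapi_server.py --port 5050
#
# then open http://localhost:5050/docs for the interactive Swagger UI.
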
import os
import sys
import json
import logging
import datetime
import traceback
from pathlib import Path
from typing import List, Optional, Dict, Any, Union
import unicodedata
import html

# FastAPI imports
from fastapi import FastAPI, HTTPException, BackgroundTasks, status
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator, model_validator
import uvicorn

# PDF generation libraries - with fallback
try:
    from reportlab.lib import colors
    from reportlab.lib.pagesizes import A4
    from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
    from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, HRFlowable
    from reportlab.pdfbase import pdfmetrics
    from reportlab.pdfbase.ttfonts import TTFont
    reportlab_available = True
except ImportError:
    reportlab_available = False

# Project-specific imports
from ddms_compliance_suite.api_caller.caller import APICallDetail
from ddms_compliance_suite.test_orchestrator import APITestOrchestrator, TestSummary
from ddms_compliance_suite.input_parser.parser import ParsedAPISpec
from ddms_compliance_suite.utils.response_utils import extract_data_for_validation
from ddms_compliance_suite.utils.data_generator import DataGenerator

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# FastAPI app instance
app = FastAPI(
    title="DMS Compliance Test Tool API",
    description="""
DMS Compliance Test Tool (FastAPI edition)

A tool for API compliance testing that supports:

- YAPI spec testing - tests driven by a YAPI definition file
- Swagger/OpenAPI testing - tests driven by an OpenAPI specification
- DMS service discovery testing - dynamically discovers DMS service APIs and tests them
- Pagination support - fetches large API sets page by page to avoid memory exhaustion
- PDF report generation - produces detailed test reports
- LLM integration - optionally uses a large language model to generate test data

Key features

- 🚀 High performance: built on FastAPI with async request handling
- 📊 Pagination: avoids memory problems with large numbers of API nodes
- 📝 Automatic docs: interactive API documentation is generated automatically
- 🔧 Flexible configuration: supports a wide range of test options
- 📈 Detailed reports: test reports in PDF and JSON formats
    """,
    version="2.0.0",
    docs_url="/docs",    # Swagger UI
    redoc_url="/redoc",  # ReDoc
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Restrict to specific origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Pydantic models for request/response
class TestConfig(BaseModel):
    """Test configuration model"""

    # API definition source (choose exactly one)
    yapi: Optional[str] = Field(None, description="Path to a YAPI definition file", exclude=True)
    swagger: Optional[str] = Field(None, description="Path to a Swagger/OpenAPI definition file", exclude=True)
    dms: Optional[str] = Field("./assets/doc/dms/domain.json", description="Path to the domain mapping file for DMS service discovery", example="./assets/doc/dms/domain.json")

    # Basic configuration
    base_url: str = Field("https://www.dev.ideas.cnpc/", description="Base URL of the API under test", example="https://www.dev.ideas.cnpc/")

    # Pagination configuration
    page_size: int = Field(10, description="Page size for DMS API pagination, default 10. Smaller values reduce memory usage", ge=1, le=10000)
    page_no: int = Field(1, description="Starting page number, beginning at 1. Useful for resuming or skipping earlier pages", ge=1)
    fetch_all_pages: bool = Field(False, description="Whether to fetch all pages. True = fetch everything, False = fetch only the requested page")

    # Filtering options
    strictness_level: str = Field("CRITICAL", description="Test strictness level", pattern="^(CRITICAL|HIGH|MEDIUM|LOW)$")

    @field_validator('base_url')
    @classmethod
    def validate_base_url(cls, v):
        if not v.startswith(('http://', 'https://')):
            raise ValueError('base_url must start with http:// or https://')
        return v

    @model_validator(mode='before')
    @classmethod
    def validate_api_source(cls, values):
        """Validate that exactly one API definition source is provided"""
        if isinstance(values, dict):
            api_sources = [values.get('yapi'), values.get('swagger'), values.get('dms')]
            non_none_sources = [s for s in api_sources if s is not None]
            if len(non_none_sources) > 1:
                raise ValueError('Only one API definition source may be provided: yapi, swagger, or dms')
            if len(non_none_sources) == 0:
                raise ValueError('An API definition source is required: yapi, swagger, or dms')
        return values

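# Example request body for POST /run using the TestConfig model above (illustrative
# values; exactly one of yapi / swagger / dms may be supplied):
#
#   {
#       "dms": "./assets/doc/dms/domain.json",
#       "base_url": "https://www.dev.ideas.cnpc/",
#       "page_size": 10,
#       "page_no": 1,
#       "fetch_all_pages": false,
#       "strictness_level": "CRITICAL"
#   }
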
class PaginationInfo(BaseModel):
    """Pagination information model"""
    page_size: int = Field(description="Page size")
    page_no_start: int = Field(description="Starting page number")
    total_pages: int = Field(description="Total number of pages")
    total_records: int = Field(description="Total number of records")
    pages_fetched: int = Field(description="Number of pages fetched")
    current_page: int = Field(description="Current page number")

class TestResponse(BaseModel):
    """Test response model"""
    status: str = Field(description="Test status", example="completed")
    message: str = Field(description="Status message")
    report_directory: str = Field(description="Path to the report directory")
    summary: Dict[str, Any] = Field(description="Test summary information")
    pagination: Optional[PaginationInfo] = Field(None, description="Pagination information (returned for DMS tests only)")

class ErrorResponse(BaseModel):
    """Error response model"""
    status: str = Field("error", description="Error status")
    message: str = Field(description="Error message")
    traceback: Optional[str] = Field(None, description="Error stack trace")

# Global variable to store running tasks
running_tasks: Dict[str, Dict[str, Any]] = {}

@app.get("/",
         summary="Health check",
         description="Check whether the API server is running normally",
         response_model=Dict[str, str])
async def health_check():
    """Health check endpoint, used by the Docker health check"""
    return {
        "status": "healthy",
        "service": "DMS Compliance API Server (FastAPI)",
        "version": "2.0.0",
        "docs_url": "/docs",
        "redoc_url": "/redoc"
    }

@app.get("/info",
         summary="Service information",
         description="Get detailed information about the API server",
         response_model=Dict[str, Any])
async def get_info():
    """Return server information"""
    return {
        "service": "DMS Compliance API Server",
        "version": "2.0.0",
        "framework": "FastAPI",
        "features": [
            "YAPI spec testing",
            "Swagger/OpenAPI testing",
            "DMS service discovery testing",
            "Pagination support",
            "PDF report generation",
            "LLM integration",
            "Automatic API documentation"
        ],
        "endpoints": {
            "health": "/",
            "info": "/info",
            "run_tests": "/run",
            "docs": "/docs",
            "redoc": "/redoc"
        },
        "reportlab_available": reportlab_available
    }

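# Example checks against the endpoints above (illustrative; assumes the default
# host/port from the __main__ block below):
#
#   curl http://localhost:5050/        # health check
#   curl http://localhost:5050/info    # service information
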
# Test logic adapted from the original Flask version
def run_tests_logic(config: dict):
    """
    Main logic for running tests, adapted from the original Flask version.
    """
    try:
        if config.get('verbose'):
            logging.getLogger('ddms_compliance_suite').setLevel(logging.DEBUG)
            logger.setLevel(logging.DEBUG)
            logger.debug("Verbose logging enabled.")

        if not any(k in config for k in ['yapi', 'swagger', 'dms']):
            raise ValueError("An API definition source is required: --yapi, --swagger, or --dms")

        if sum(k in config for k in ['yapi', 'swagger', 'dms']) > 1:
            raise ValueError("API definition sources are mutually exclusive.")

        # Setup output directory with timestamp
        base_output_dir = Path(config.get('output', './test_reports'))
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        output_directory = base_output_dir / timestamp
        output_directory.mkdir(parents=True, exist_ok=True)
        logger.info(f"Test reports will be saved to: {output_directory.resolve()}")
        logger.debug(f"Resolved test configuration: {config}")
        # Initialize the orchestrator
        orchestrator = APITestOrchestrator(
            base_url=config['base_url'],
            custom_test_cases_dir=config.get('custom_test_cases_dir'),
            llm_api_key=config.get('llm_api_key'),
            llm_base_url=config.get('llm_base_url'),
            llm_model_name=config.get('llm_model_name'),
            use_llm_for_request_body=config.get('use_llm_for_request_body', False),
            use_llm_for_path_params=config.get('use_llm_for_path_params', False),
            use_llm_for_query_params=config.get('use_llm_for_query_params', False),
            use_llm_for_headers=config.get('use_llm_for_headers', False),
            output_dir=str(output_directory),
            stages_dir=config.get('stages_dir'),
            strictness_level=config.get('strictness_level', 'CRITICAL'),
            ignore_ssl=config.get('ignore_ssl', False)
        )

        test_summary: Optional[TestSummary] = None
        parsed_spec: Optional[ParsedAPISpec] = None
        pagination_info: Dict[str, Any] = {}

        if 'yapi' in config:
            logger.info(f"Running tests from YAPI file: {config['yapi']}")
            test_summary, parsed_spec = orchestrator.run_tests_from_yapi(
                yapi_file_path=config['yapi'],
                categories=config.get('categories'),
                custom_test_cases_dir=config.get('custom_test_cases_dir')
            )
        elif 'swagger' in config:
            logger.info(f"Running tests from Swagger file: {config['swagger']}")
            test_summary, parsed_spec = orchestrator.run_tests_from_swagger(
                swagger_file_path=config['swagger'],
                tags=config.get('tags'),
                custom_test_cases_dir=config.get('custom_test_cases_dir')
            )
        elif 'dms' in config:
            logger.info(f"Running tests from DMS service discovery: {config['dms']}")
            test_summary, parsed_spec, pagination_info = orchestrator.run_tests_from_dms(
                domain_mapping_path=config['dms'],
                categories=config.get('categories'),
                custom_test_cases_dir=config.get('custom_test_cases_dir'),
                page_size=config.get('page_size', 1000),
                page_no_start=config.get('page_no', 1),
                fetch_all_pages=config.get('fetch_all_pages', True)
            )

        if not parsed_spec:
            raise RuntimeError("Failed to parse the API specification.")

        if test_summary and config.get('stages_dir') and parsed_spec:
            logger.info(f"Executing API test stages from directory: {config['stages_dir']}")
            orchestrator.run_stages_from_spec(parsed_spec, test_summary)

        if test_summary:
            # Save main summary
            main_report_file_path = output_directory / "summary.json"
            with open(main_report_file_path, 'w', encoding='utf-8') as f:
                f.write(test_summary.to_json(pretty=True))

            # Save API call details
            api_calls_filename = "api_call_details.md"
            save_api_call_details_to_markdown(
                orchestrator.get_api_call_details(),
                str(output_directory),
                filename=api_calls_filename
            )

            failed_count = getattr(test_summary, 'endpoints_failed', 0) + getattr(test_summary, 'test_cases_failed', 0)
            error_count = getattr(test_summary, 'endpoints_error', 0) + getattr(test_summary, 'test_cases_error', 0)

            result = {
                "status": "completed",
                "message": "Tests finished." if failed_count == 0 and error_count == 0 else "Tests finished with failures or errors.",
                "report_directory": str(output_directory.resolve()),
                "summary": test_summary.to_dict()
            }

            # Include pagination info in the result if present
            if pagination_info:
                result["pagination"] = pagination_info

            return result
        else:
            raise RuntimeError("Test execution failed to produce a summary.")

    except Exception as e:
        logger.error(f"An unexpected error occurred during test execution: {e}", exc_info=True)
        return {
            "status": "error",
            "message": str(e),
            "traceback": traceback.format_exc()
        }

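# The markdown writer below produces a report of roughly this shape for each call
# (illustrative excerpt):
#
#   ## 1. <endpoint name>
#   Request URL: `...`
#   Request method: `...`
#   followed by request headers/body, response status, headers and body as fenced JSON blocks.
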
def save_api_call_details_to_markdown(api_call_details: List[APICallDetail], output_dir: str, filename: str = "api_call_details.md"):
    """Save API call details to a markdown file"""
    try:
        output_path = Path(output_dir) / filename

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write("# API Call Details\n\n")

            for i, detail in enumerate(api_call_details, 1):
                f.write(f"## {i}. {detail.endpoint_name}\n\n")
                f.write(f"Request URL: `{detail.request_url}`\n\n")
                f.write(f"Request method: `{detail.request_method}`\n\n")

                if detail.request_headers:
                    f.write("Request headers:\n```json\n")
                    f.write(json.dumps(detail.request_headers, indent=2, ensure_ascii=False))
                    f.write("\n```\n\n")

                if detail.request_body:
                    f.write("Request body:\n```json\n")
                    f.write(json.dumps(detail.request_body, indent=2, ensure_ascii=False))
                    f.write("\n```\n\n")

                f.write(f"Response status code: `{detail.response_status_code}`\n\n")

                if detail.response_headers:
                    f.write("Response headers:\n```json\n")
                    f.write(json.dumps(detail.response_headers, indent=2, ensure_ascii=False))
                    f.write("\n```\n\n")

                if detail.response_body:
                    f.write("Response body:\n```json\n")
                    f.write(json.dumps(detail.response_body, indent=2, ensure_ascii=False))
                    f.write("\n```\n\n")

                f.write("---\n\n")

        logger.info(f"API call details saved to: {output_path}")

    except Exception as e:
        logger.error(f"Error saving API call details: {e}")

@app.post("/run",
          summary="Run API compliance tests",
          description="""
Main endpoint for running API compliance tests.

Three API definition sources are supported:
- YAPI: driven by a YAPI definition file
- Swagger/OpenAPI: driven by an OpenAPI specification file
- DMS: dynamically discovers DMS service APIs

Pagination support
For DMS tests, the API list can be fetched page by page to avoid memory exhaustion:
- `page_size`: number of APIs fetched per page (default 10)
- detailed pagination statistics are returned

LLM integration
A large language model can optionally be used to generate test data:
- intelligently generates request bodies, path parameters, query parameters, etc.
- improves test coverage and data variety
          """,
          response_model=TestResponse,
          responses={
              200: {"description": "Tests executed successfully"},
              400: {"description": "Invalid request parameters", "model": ErrorResponse},
              500: {"description": "Internal server error", "model": ErrorResponse}
          })
async def run_api_tests(config: TestConfig):
    """
    Run API compliance tests.

    - config: test configuration, including the API definition source and test parameters
    - returns: test results, including summary information and pagination info (where applicable)
    """
    try:
        logger.info(f"Starting test run with configuration: {config.model_dump()}")

        # Convert Pydantic model to dict for compatibility
        config_dict = config.model_dump(exclude_none=True)

        # Add hidden parameters with default values
        hidden_defaults = {
            "categories": [],
            "tags": [],
            "ignore_ssl": True,
            "output": "./test_reports",
            "generate_pdf": True,
            "custom_test_cases_dir": "./custom_testcases",
            "stages_dir": "./custom_stages",
            # Read the LLM API key from the environment (variable name assumed here)
            # rather than hardcoding a credential in source
            "llm_api_key": os.environ.get("LLM_API_KEY", ""),
            "llm_base_url": "https://aiproxy.petrotech.cnpc/v1",
            "llm_model_name": "deepseek-v3",
            "use_llm_for_request_body": False,
            "use_llm_for_path_params": False,
            "use_llm_for_query_params": False,
            "use_llm_for_headers": False,
            "verbose": False
        }

        # Merge hidden defaults into the config (these defaults take precedence)
        config_dict.update(hidden_defaults)

        result = run_tests_logic(config_dict)

        if result['status'] == 'error':
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=result
            )

        return result

    except HTTPException:
        # Re-raise HTTPExceptions unchanged so their status codes are preserved
        raise
    except ValueError as e:
        logger.error(f"Validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "status": "error",
                "message": str(e)
            }
        )
    except Exception as e:
        logger.error(f"An error occurred in the API endpoint: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "status": "error",
                "message": str(e),
                "traceback": traceback.format_exc()
            }
        )

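# Example invocation of the /run endpoint above (illustrative; assumes the server
# listens on the default port 5050 and uses DMS service discovery):
#
#   curl -X POST http://localhost:5050/run \
#        -H "Content-Type: application/json" \
#        -d '{"dms": "./assets/doc/dms/domain.json", "base_url": "https://www.dev.ideas.cnpc/"}'
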
@app.get("/reports/{report_id}",
         summary="Download a test report",
         description="Download the test report file for the given report ID")
async def download_report(report_id: str, file_type: str = "summary.json"):
    """
    Download a test report file.

    - report_id: report ID (usually a timestamp)
    - file_type: file type, one of: summary.json, api_call_details.md
    """
    try:
        report_dir = Path("./test_reports") / report_id
        file_path = report_dir / file_type

        if not file_path.exists():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Report file not found: {file_type}"
            )

        return FileResponse(
            path=str(file_path),
            filename=file_type,
            media_type='application/octet-stream'
        )

    except HTTPException:
        # Let the 404 above propagate instead of converting it into a 500
        raise
    except Exception as e:
        logger.error(f"Error downloading report: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error downloading report: {str(e)}"
        )

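# Example download of a report produced by /run (illustrative; the report ID is the
# timestamped directory name returned in `report_directory`):
#
#   curl "http://localhost:5050/reports/2024-01-01_12-00-00?file_type=summary.json" -o summary.json
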
@app.get("/reports",
         summary="List all test reports",
         description="Get a list of all available test reports")
async def list_reports():
    """List all available test reports"""
    try:
        reports_dir = Path("./test_reports")
        if not reports_dir.exists():
            return {"reports": []}

        reports = []
        for report_dir in reports_dir.iterdir():
            if report_dir.is_dir():
                summary_file = report_dir / "summary.json"
                if summary_file.exists():
                    try:
                        with open(summary_file, 'r', encoding='utf-8') as f:
                            summary = json.load(f)

                        reports.append({
                            "id": report_dir.name,
                            "timestamp": report_dir.name,
                            "path": str(report_dir),
                            "summary": {
                                "endpoints_total": summary.get("endpoints_total", 0),
                                "endpoints_passed": summary.get("endpoints_passed", 0),
                                "endpoints_failed": summary.get("endpoints_failed", 0),
                                "test_cases_total": summary.get("test_cases_total", 0)
                            }
                        })
                    except Exception as e:
                        logger.warning(f"Error reading summary for {report_dir.name}: {e}")

        # Sort by timestamp (newest first)
        reports.sort(key=lambda x: x["timestamp"], reverse=True)

        return {"reports": reports}

    except Exception as e:
        logger.error(f"Error listing reports: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error listing reports: {str(e)}"
        )

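# Example listing of available reports via the /reports endpoint above (illustrative):
#
#   curl http://localhost:5050/reports
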
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="DMS Compliance Test Tool FastAPI server")
    parser.add_argument("--host", default="0.0.0.0", help="Server host address")
    parser.add_argument("--port", type=int, default=5050, help="Server port")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload (development mode)")
    parser.add_argument("--workers", type=int, default=1, help="Number of worker processes")

    args = parser.parse_args()

    logger.info(f"Starting FastAPI server on {args.host}:{args.port}")
    logger.info(f"API docs: http://{args.host}:{args.port}/docs")
    logger.info(f"ReDoc docs: http://{args.host}:{args.port}/redoc")

    uvicorn.run(
        "fastapi_server:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=args.workers if not args.reload else 1,
        log_level="info"
    )
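
# Alternatively, the app can be served directly with the uvicorn CLI (illustrative;
# assumes this file is named fastapi_server.py, matching the uvicorn.run() target above):
#
#   uvicorn fastapi_server:app --host 0.0.0.0 --port 5050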