#!/usr/bin/env python3
"""
DMS compliance test tool - FastAPI API server.

Exposes the API test orchestrator over HTTP and serves auto-generated
interactive API documentation (Swagger UI at /docs, ReDoc at /redoc).
"""

import os
import sys
import json
import logging
import datetime
import traceback
from pathlib import Path
from typing import List, Optional, Dict, Any, Union
import unicodedata
import html

# FastAPI imports
from fastapi import FastAPI, HTTPException, BackgroundTasks, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator, model_validator
import uvicorn

# PDF generation libraries - optional; availability is reported by /info.
try:
    from reportlab.lib import colors
    from reportlab.lib.pagesizes import A4
    from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
    from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, HRFlowable
    from reportlab.pdfbase import pdfmetrics
    from reportlab.pdfbase.ttfonts import TTFont
    reportlab_available = True
except ImportError:
    reportlab_available = False

# Project-specific imports
from ddms_compliance_suite.api_caller.caller import APICallDetail
from ddms_compliance_suite.test_orchestrator import APITestOrchestrator, TestSummary
from ddms_compliance_suite.input_parser.parser import ParsedAPISpec
from ddms_compliance_suite.utils.response_utils import extract_data_for_validation
from ddms_compliance_suite.utils.data_generator import DataGenerator

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Default directory for test reports. /reports and /reports/{id} only look here,
# so reports written to a custom `output` directory are not listed or served.
DEFAULT_REPORTS_DIR = "./test_reports"

# FastAPI app instance
# NOTE(review): this version ("1.0.0") differs from the "2.0.0" reported by
# "/" and "/info" - confirm which one is correct.
app = FastAPI(
    title="DMS合规性测试工具 API",
    description="""
DMS合规性测试工具 FastAPI版本

这是一个用于API合规性测试的工具,支持:

YAPI规范测试 - 基于YAPI定义文件的测试
Swagger/OpenAPI测试 - 基于OpenAPI规范的测试
DMS服务发现测试 - 动态发现DMS服务的API进行测试
分页支持 - 支持大量API的分页获取,避免内存溢出
PDF报告生成 - 生成详细的测试报告
LLM集成 - 支持大语言模型辅助生成测试数据

主要特性

- 🚀 高性能: 基于FastAPI,支持异步处理
- 📊 分页支持: 解决大量API节点的内存问题
- 📝 自动文档: 自动生成交互式API文档
- 🔧 灵活配置: 支持多种测试配置选项
- 📈 详细报告: 生成PDF和JSON格式的测试报告
""",
    version="1.0.0",
    docs_url="/docs",  # Swagger UI
    redoc_url="/redoc",  # ReDoc
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Should be restricted to specific domains in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Pydantic models for request/response
class TestConfig(BaseModel):
    """Request body for POST /run (测试配置模型)."""

    # API definition source: exactly one of yapi / swagger / dms (enforced below)
    yapi: Optional[str] = Field(None, description="YAPI定义文件路径", examples=["./api_spec.json"])
    swagger: Optional[str] = Field(None, description="Swagger/OpenAPI定义文件路径", examples=["./openapi.yaml"])
    dms: Optional[str] = Field(None, description="DMS服务发现的domain mapping文件路径", examples=["./assets/doc/dms/domain.json"])

    # Basic configuration
    base_url: str = Field(..., description="API基础URL", examples=["https://api.example.com"])

    # Pagination (DMS only)
    page_size: int = Field(1000, description="DMS API分页大小,默认1000。较小的值可以减少内存使用", ge=1, le=10000)
    page_no: int = Field(1, description="起始页码,从1开始。可用于断点续传或跳过前面的页面", ge=1)
    fetch_all_pages: bool = Field(True, description="是否获取所有页面。True=获取所有数据,False=只获取指定页面")

    # Filters
    categories: Optional[List[str]] = Field(None, description="YAPI分类列表", examples=[["用户管理", "订单系统"]])
    tags: Optional[List[str]] = Field(None, description="Swagger标签列表", examples=[["user", "order"]])
    strictness_level: str = Field("CRITICAL", description="测试严格等级", pattern="^(CRITICAL|HIGH|MEDIUM|LOW)$")

    # SSL / security
    ignore_ssl: bool = Field(False, description="忽略SSL证书验证(不推荐在生产环境使用)")

    # Output
    output: str = Field(DEFAULT_REPORTS_DIR, description="测试报告输出目录")
    generate_pdf: bool = Field(True, description="是否生成PDF报告")

    # Custom tests
    custom_test_cases_dir: Optional[str] = Field(None, description="自定义测试用例目录路径")
    stages_dir: Optional[str] = Field(None, description="自定义测试阶段目录路径")

    # LLM configuration
    llm_api_key: Optional[str] = Field(None, description="LLM API密钥")
    llm_base_url: Optional[str] = Field(None, description="LLM API基础URL")
    llm_model_name: Optional[str] = Field("gpt-3.5-turbo", description="LLM模型名称")
    use_llm_for_request_body: bool = Field(False, description="使用LLM生成请求体")
    use_llm_for_path_params: bool = Field(False, description="使用LLM生成路径参数")
    use_llm_for_query_params: bool = Field(False, description="使用LLM生成查询参数")
    use_llm_for_headers: bool = Field(False, description="使用LLM生成请求头")

    # Debugging
    verbose: bool = Field(False, description="启用详细日志输出")

    @field_validator('base_url')
    @classmethod
    def validate_base_url(cls, v):
        """Reject base URLs that are not http:// or https://."""
        if not v.startswith(('http://', 'https://')):
            raise ValueError('base_url must start with http:// or https://')
        return v

    @model_validator(mode='before')
    @classmethod
    def validate_api_source(cls, values):
        """Ensure exactly one of yapi / swagger / dms is provided (raw input, before field parsing)."""
        if isinstance(values, dict):
            api_sources = [values.get('yapi'), values.get('swagger'), values.get('dms')]
            non_none_sources = [s for s in api_sources if s is not None]
            if len(non_none_sources) > 1:
                raise ValueError('只能选择一个API定义源:yapi、swagger或dms')
            if len(non_none_sources) == 0:
                raise ValueError('必须提供一个API定义源:yapi、swagger或dms')
        return values


class PaginationInfo(BaseModel):
    """Pagination statistics returned by DMS test runs (分页信息模型)."""
    page_size: int = Field(description="页面大小")
    page_no_start: int = Field(description="起始页码")
    total_pages: int = Field(description="总页数")
    total_records: int = Field(description="总记录数")
    pages_fetched: int = Field(description="已获取页数")
    current_page: int = Field(description="当前页码")


class TestResponse(BaseModel):
    """Successful response of POST /run (测试响应模型)."""
    status: str = Field(description="测试状态", examples=["completed"])
    message: str = Field(description="状态消息")
    report_directory: str = Field(description="报告目录路径")
    summary: Dict[str, Any] = Field(description="测试摘要信息")
    pagination: Optional[PaginationInfo] = Field(None, description="分页信息(仅DMS测试时返回)")


class ErrorResponse(BaseModel):
    """Error response shape documented for POST /run (错误响应模型)."""
    status: str = Field("error", description="错误状态")
    message: str = Field(description="错误消息")
    traceback: Optional[str] = Field(None, description="错误堆栈跟踪")


# Global variable to store running tasks (not referenced in this module)
running_tasks: Dict[str, Dict[str, Any]] = {}


@app.get("/", summary="健康检查", description="检查API服务器是否正常运行", response_model=Dict[str, str])
async def health_check():
    """Health-check endpoint (also used by the Docker health check)."""
    return {
        "status": "healthy",
        "service": "DMS Compliance API Server (FastAPI)",
        "version": "2.0.0",
        "docs_url": "/docs",
        "redoc_url": "/redoc"
    }


@app.get("/info", summary="服务信息", description="获取API服务器的详细信息", response_model=Dict[str, Any])
async def get_info():
    """Return static server information, features and endpoint map."""
    return {
        "service": "DMS Compliance API Server",
        "version": "2.0.0",
        "framework": "FastAPI",
        "features": [
            "YAPI规范测试",
            "Swagger/OpenAPI测试",
            "DMS服务发现测试",
            "分页支持",
            "PDF报告生成",
            "LLM集成",
            "自动API文档"
        ],
        "endpoints": {
            "health": "/",
            "info": "/info",
            "run_tests": "/run",
            "docs": "/docs",
            "redoc": "/redoc"
        },
        "reportlab_available": reportlab_available
    }


# Test logic carried over from the original Flask version
def run_tests_logic(config: dict):
    """
    Run a complete test session and write its reports.

    Args:
        config: snake_case option dict (same keys as TestConfig); exactly one
            of 'yapi', 'swagger' or 'dms' must be present.

    Returns:
        On success: {"status": "completed", "message", "report_directory",
        "summary"[, "pagination"]}. On any exception: {"status": "error",
        "message", "traceback"}. This function never raises.
    """
    try:
        if config.get('verbose'):
            # NOTE: changes process-wide log levels; they persist after this run.
            logging.getLogger('ddms_compliance_suite').setLevel(logging.DEBUG)
            logger.setLevel(logging.DEBUG)
            logger.debug("Verbose logging enabled.")

        if not any(k in config for k in ['yapi', 'swagger', 'dms']):
            raise ValueError("An API definition source is required: --yapi, --swagger, or --dms")
        if sum(k in config for k in ['yapi', 'swagger', 'dms']) > 1:
            raise ValueError("API definition sources are mutually exclusive.")

        # Setup output directory with timestamp
        base_output_dir = Path(config.get('output', DEFAULT_REPORTS_DIR))
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        output_directory = base_output_dir / timestamp
        output_directory.mkdir(parents=True, exist_ok=True)
        logger.info(f"Test reports will be saved to: {output_directory.resolve()}")

        # Initialize the orchestrator
        orchestrator = APITestOrchestrator(
            base_url=config['base_url'],
            custom_test_cases_dir=config.get('custom_test_cases_dir'),
            llm_api_key=config.get('llm_api_key'),
            llm_base_url=config.get('llm_base_url'),
            llm_model_name=config.get('llm_model_name'),
            use_llm_for_request_body=config.get('use_llm_for_request_body', False),
            use_llm_for_path_params=config.get('use_llm_for_path_params', False),
            use_llm_for_query_params=config.get('use_llm_for_query_params', False),
            use_llm_for_headers=config.get('use_llm_for_headers', False),
            output_dir=str(output_directory),
            stages_dir=config.get('stages_dir'),
            strictness_level=config.get('strictness_level', 'CRITICAL'),
            ignore_ssl=config.get('ignore_ssl', False)
        )

        test_summary: Optional[TestSummary] = None
        parsed_spec: Optional[ParsedAPISpec] = None
        pagination_info: Dict[str, Any] = {}

        if 'yapi' in config:
            logger.info(f"Running tests from YAPI file: {config['yapi']}")
            test_summary, parsed_spec = orchestrator.run_tests_from_yapi(
                yapi_file_path=config['yapi'],
                categories=config.get('categories'),
                custom_test_cases_dir=config.get('custom_test_cases_dir')
            )
        elif 'swagger' in config:
            logger.info(f"Running tests from Swagger file: {config['swagger']}")
            test_summary, parsed_spec = orchestrator.run_tests_from_swagger(
                swagger_file_path=config['swagger'],
                tags=config.get('tags'),
                custom_test_cases_dir=config.get('custom_test_cases_dir')
            )
        elif 'dms' in config:
            logger.info(f"Running tests from DMS service discovery: {config['dms']}")
            test_summary, parsed_spec, pagination_info = orchestrator.run_tests_from_dms(
                domain_mapping_path=config['dms'],
                categories=config.get('categories'),
                custom_test_cases_dir=config.get('custom_test_cases_dir'),
                page_size=config.get('page_size', 1000),
                page_no_start=config.get('page_no', 1),
                fetch_all_pages=config.get('fetch_all_pages', True)
            )

        if not parsed_spec:
            raise RuntimeError("Failed to parse the API specification.")

        if test_summary and config.get('stages_dir') and parsed_spec:
            logger.info(f"Executing API test stages from directory: {config['stages_dir']}")
            stage_summary = orchestrator.run_stages_from_spec(parsed_spec, config['stages_dir'])
            if stage_summary:
                test_summary.merge_stage_summary(stage_summary)

        if not test_summary:
            raise RuntimeError("Test execution failed to produce a summary.")

        # Save main summary
        main_report_file_path = output_directory / "summary.json"
        with open(main_report_file_path, 'w', encoding='utf-8') as f:
            f.write(test_summary.to_json(pretty=True))

        # Save API call details
        api_calls_filename = "api_call_details.md"
        save_api_call_details_to_markdown(
            orchestrator.get_api_call_details(),
            str(output_directory),
            filename=api_calls_filename
        )

        failed_count = getattr(test_summary, 'endpoints_failed', 0) + getattr(test_summary, 'test_cases_failed', 0)
        error_count = getattr(test_summary, 'endpoints_error', 0) + getattr(test_summary, 'test_cases_error', 0)

        result = {
            "status": "completed",
            "message": "Tests finished." if failed_count == 0 and error_count == 0 else "Tests finished with failures or errors.",
            "report_directory": str(output_directory.resolve()),
            "summary": test_summary.to_dict()
        }

        # Include pagination statistics when present (DMS runs)
        if pagination_info:
            result["pagination"] = pagination_info

        return result

    except Exception as e:
        logger.error(f"An unexpected error occurred during test execution: {e}", exc_info=True)
        return {
            "status": "error",
            "message": str(e),
            "traceback": traceback.format_exc()
        }


def _write_json_section(f, title: str, payload: Any) -> None:
    """Write one bold-titled fenced JSON section; non-serializable values are stringified."""
    f.write(f"**{title}**:\n```json\n")
    # default=str: a bytes/custom-object body must not abort the whole report.
    f.write(json.dumps(payload, indent=2, ensure_ascii=False, default=str))
    f.write("\n```\n\n")


def save_api_call_details_to_markdown(api_call_details: List[APICallDetail], output_dir: str, filename: str = "api_call_details.md"):
    """
    Write a Markdown report with one section per API call.

    Best effort: any failure is logged and swallowed so that report writing
    never fails the test run.

    Args:
        api_call_details: recorded calls (endpoint name, request/response data).
        output_dir: directory to write into (must already exist).
        filename: name of the Markdown file inside ``output_dir``.
    """
    try:
        output_path = Path(output_dir) / filename
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write("# API调用详情\n\n")

            for i, detail in enumerate(api_call_details, 1):
                f.write(f"## {i}. {detail.endpoint_name}\n\n")
                f.write(f"**请求URL**: `{detail.request_url}`\n\n")
                f.write(f"**请求方法**: `{detail.request_method}`\n\n")

                if detail.request_headers:
                    _write_json_section(f, "请求头", detail.request_headers)
                if detail.request_body:
                    _write_json_section(f, "请求体", detail.request_body)

                f.write(f"**响应状态码**: `{detail.response_status_code}`\n\n")

                if detail.response_headers:
                    _write_json_section(f, "响应头", detail.response_headers)
                if detail.response_body:
                    _write_json_section(f, "响应体", detail.response_body)

                f.write("---\n\n")

        logger.info(f"API call details saved to: {output_path}")
    except Exception as e:
        logger.error(f"Error saving API call details: {e}", exc_info=True)


@app.post("/run",
          summary="执行API合规性测试",
          description="""
执行API合规性测试的主要端点。

支持三种API定义源:
- **YAPI**: 基于YAPI定义文件
- **Swagger/OpenAPI**: 基于OpenAPI规范文件
- **DMS**: 动态发现DMS服务的API

### 分页支持
对于DMS测试,支持分页获取API列表,避免内存溢出:
- `page_size`: 每页获取的API数量(默认1000)
- 返回详细的分页统计信息

### LLM集成
可选择使用大语言模型生成测试数据:
- 智能生成请求体、路径参数、查询参数等
- 提高测试覆盖率和数据多样性
""",
          response_model=TestResponse,
          responses={
              200: {"description": "测试执行成功"},
              400: {"description": "请求参数错误", "model": ErrorResponse},
              500: {"description": "服务器内部错误", "model": ErrorResponse}
          })
async def run_api_tests(config: TestConfig):
    """
    Run an API compliance test session.

    Args:
        config: validated test configuration (API source, options, LLM settings).

    Returns:
        The success dict produced by ``run_tests_logic``.

    Raises:
        HTTPException: 500 when the test run reports an error, 400 on ValueError.
    """
    try:
        # Never write the LLM API key to the logs.
        safe_config = config.model_dump(exclude={'llm_api_key'})
        logger.info(f"Starting test run with configuration: {safe_config}")

        # Keys stay snake_case: run_tests_logic reads config['base_url'],
        # config.get('custom_test_cases_dir'), ... (converting them to hyphens
        # made every run fail with KeyError('base_url')).
        config_dict = config.model_dump(exclude_none=True)

        # run_tests_logic blocks (network I/O, file writes) for the whole run;
        # execute it in a worker thread so the event loop keeps serving other
        # requests such as the health check.
        result = await run_in_threadpool(run_tests_logic, config_dict)

        if result['status'] == 'error':
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=result
            )

        return result

    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap it below.
        raise
    except ValueError as e:
        logger.error(f"Validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "status": "error",
                "message": str(e)
            }
        )
    except Exception as e:
        logger.error(f"An error occurred in the API endpoint: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "status": "error",
                "message": str(e),
                "traceback": traceback.format_exc()
            }
        )


@app.get("/reports/{report_id}",
         summary="下载测试报告",
         description="根据报告ID下载对应的测试报告文件")
async def download_report(report_id: str, file_type: str = "summary.json"):
    """
    Download one file from a report directory under DEFAULT_REPORTS_DIR.

    Args:
        report_id: report directory name (the run timestamp).
        file_type: file inside the report, e.g. summary.json or api_call_details.md.

    Raises:
        HTTPException: 400 if the path escapes the reports directory, 404 if the
            file does not exist, 500 on unexpected errors.
    """
    try:
        reports_root = Path(DEFAULT_REPORTS_DIR).resolve()
        file_path = (reports_root / report_id / file_type).resolve()

        # Both parts are user input: reject "../" or absolute paths that would
        # escape the reports directory (path traversal).
        if reports_root not in file_path.parents:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid report path: {report_id}/{file_type}"
            )

        if not file_path.is_file():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Report file not found: {file_type}"
            )

        return FileResponse(
            path=str(file_path),
            filename=file_path.name,
            media_type='application/octet-stream'
        )
    except HTTPException:
        # Preserve 400/404 instead of turning them into 500 below.
        raise
    except Exception as e:
        logger.error(f"Error downloading report: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error downloading report: {str(e)}"
        )


@app.get("/reports",
         summary="列出所有测试报告",
         description="获取所有可用的测试报告列表")
async def list_reports():
    """
    List report directories under DEFAULT_REPORTS_DIR that contain summary.json.

    Directories whose summary cannot be read are skipped with a warning. Results
    are sorted newest first by directory name (a sortable timestamp).
    """
    try:
        reports_dir = Path(DEFAULT_REPORTS_DIR)
        if not reports_dir.exists():
            return {"reports": []}

        reports = []
        for report_dir in reports_dir.iterdir():
            if not report_dir.is_dir():
                continue
            summary_file = report_dir / "summary.json"
            if not summary_file.exists():
                continue
            try:
                with open(summary_file, 'r', encoding='utf-8') as f:
                    summary = json.load(f)

                reports.append({
                    "id": report_dir.name,
                    "timestamp": report_dir.name,
                    "path": str(report_dir),
                    "summary": {
                        "endpoints_total": summary.get("endpoints_total", 0),
                        "endpoints_passed": summary.get("endpoints_passed", 0),
                        "endpoints_failed": summary.get("endpoints_failed", 0),
                        "test_cases_total": summary.get("test_cases_total", 0)
                    }
                })
            except Exception as e:
                logger.warning(f"Error reading summary for {report_dir.name}: {e}")

        # Sort by timestamp (newest first)
        reports.sort(key=lambda x: x["timestamp"], reverse=True)

        return {"reports": reports}

    except Exception as e:
        logger.error(f"Error listing reports: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error listing reports: {str(e)}"
        )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="DMS合规性测试工具 FastAPI服务器")
    parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
    parser.add_argument("--port", type=int, default=5050, help="服务器端口")
    parser.add_argument("--reload", action="store_true", help="启用自动重载(开发模式)")
    parser.add_argument("--workers", type=int, default=1, help="工作进程数")

    args = parser.parse_args()

    logger.info(f"Starting FastAPI server on {args.host}:{args.port}")
    logger.info(f"API文档地址: http://{args.host}:{args.port}/docs")
    logger.info(f"ReDoc文档地址: http://{args.host}:{args.port}/redoc")

    # Import string (not the app object) so --reload/--workers can re-import it;
    # assumes this file is importable as "fastapi_server" from the CWD.
    uvicorn.run(
        "fastapi_server:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=args.workers if not args.reload else 1,
        log_level="info"
    )