compliance/ddms_compliance_suite/test_orchestrator.py
2025-08-07 23:54:35 +08:00

2700 lines
167 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试编排器模块
负责组合API解析器、API调用器、验证器和规则执行器进行端到端的API测试
"""
import logging
import json
import time
import os # Added os for path operations
import re
from typing import Dict, List, Any, Optional, Union, Tuple, Type, ForwardRef
from enum import Enum
import datetime
import datetime as dt
from uuid import UUID
from dataclasses import asdict as dataclass_asdict, is_dataclass
import copy
from collections import defaultdict
from pathlib import Path
from urllib.parse import urljoin # <-- ADDED
from pydantic import BaseModel, Field, create_model, HttpUrl # Added HttpUrl for Literal type hint if needed
from pydantic.networks import EmailStr
from pydantic.types import Literal # Explicitly import Literal
from .input_parser.parser import InputParser, YAPIEndpoint, SwaggerEndpoint, ParsedYAPISpec, ParsedSwaggerSpec, ParsedAPISpec, DMSEndpoint, ParsedDMSSpec
from .input_parser.parser import BaseEndpoint, DMSEndpoint
from .api_caller.caller import APICaller, APIRequest, APIResponse, APICallDetail # Ensure APICallDetail is imported
from .json_schema_validator.validator import JSONSchemaValidator
from .test_framework_core import ValidationResult, TestSeverity, APIRequestContext, APIResponseContext, BaseAPITestCase
from .test_case_registry import TestCaseRegistry
from .utils import schema_utils
from .utils.common_utils import format_url_with_path_params
# 新增导入
from .stage_framework import BaseAPIStage, ExecutedStageResult, ExecutedStageStepResult, StageStepDefinition, APIOperationSpec
from .stage_registry import StageRegistry # For managing stages
try:
from .llm_utils.llm_service import LLMService
except ImportError:
LLMService = None
logging.getLogger(__name__).info("LLMService 未找到LLM 相关功能将不可用。")
from ddms_compliance_suite.utils.schema_provider import SchemaProvider
_dynamic_model_cache: Dict[str, Type[BaseModel]] = {}
class ExecutedTestCaseResult:
    """Holds the outcome of one APITestCase executed against an applicable endpoint."""

    class Status(str, Enum):
        """Execution status of an individual test case."""
        PASSED = "通过"
        FAILED = "失败"
        ERROR = "执行错误"  # the test-case code itself crashed, not an API validation failure
        SKIPPED = "跳过"  # the test case was skipped because some condition was not met

    def __init__(self,
                 test_case_id: str,
                 test_case_name: str,
                 test_case_severity: TestSeverity,
                 status: Status,
                 validation_points: List[ValidationResult],
                 message: str = "",
                 duration: float = 0.0):
        # message: overall note, e.g. the error text when execution itself failed
        self.test_case_id = test_case_id
        self.test_case_name = test_case_name
        self.test_case_severity = test_case_severity
        self.status = status
        self.validation_points = validation_points if validation_points else []
        self.message = message
        self.duration = duration  # wall-clock time spent running this test case
        self.timestamp = datetime.datetime.now()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this result into a plain dict for reporting."""
        summary_msg = self.message
        if not summary_msg and self.validation_points:
            # validation_points holds plain dicts here, so use dictionary access
            failures = []
            for vp in self.validation_points:
                if isinstance(vp, dict) and not vp.get("passed") and vp.get("message"):
                    failures.append(vp.get("message"))
            if failures:
                summary_msg = "; ".join(failures)
            elif self.status == self.Status.FAILED:
                # Fallback message when no specific failure messages are available
                summary_msg = "One or more validation points failed without a detailed message."
            else:
                summary_msg = "All validation points passed."
        return {
            "test_case_id": self.test_case_id,
            "test_case_name": self.test_case_name,
            "test_case_severity": self.test_case_severity.name,
            "status": self.status.value,
            "message": summary_msg,
            "duration_seconds": self.duration,
            "timestamp": self.timestamp.isoformat(),
            # already a list of dictionaries, return as-is
            "validation_points": self.validation_points,
        }
class TestResult:  # the former per-endpoint TestResult, restructured
    """
    Overall result of running every applicable APITestCase against a single
    API endpoint.
    (This class replaces the old TestResult and was structurally reworked.)
    """

    class Status(str, Enum):  # unchanged enum; now describes the endpoint's aggregate state
        """Endpoint-level test status."""
        PASSED = "通过"  # all relevant test cases passed
        FAILED = "失败"  # at least one relevant test case failed
        ERROR = "错误"  # the test run itself errored (test code / environment, not the API)
        SKIPPED = "跳过"  # testing of the whole endpoint was skipped
        PARTIAL_SUCCESS = "部分成功"  # some non-critical test cases failed, but critical ones passed

    def __init__(self,
                 endpoint_id: str,  # usually "method + path"
                 endpoint_name: str,  # human-readable API name/title
                 overall_status: Status = Status.SKIPPED,  # updated later from test-case results
                 start_time: Optional[datetime.datetime] = None
                 ):
        self.endpoint_id = endpoint_id
        self.endpoint_name = endpoint_name
        self.overall_status = overall_status
        self.executed_test_cases: List[ExecutedTestCaseResult] = []
        self.start_time = start_time if start_time else datetime.datetime.now()
        self.end_time: Optional[datetime.datetime] = None
        self.error_message: Optional[str] = None  # set when the whole endpoint test errored
        self.message: Optional[str] = None  # human-readable note (e.g. strictness verdict)
        self.strictness_level: Optional[TestSeverity] = None

    def add_executed_test_case_result(self, result: ExecutedTestCaseResult):
        """Append the result of one executed test case."""
        self.executed_test_cases.append(result)

    def finalize_endpoint_test(self, strictness_level: Optional[TestSeverity] = None):
        """Derive the endpoint's overall status from its executed test cases.

        Args:
            strictness_level: when provided, only failed test cases whose
                severity is >= this level can mark the endpoint FAILED;
                lower-severity failures are tolerated (noted in self.message).
        """
        self.end_time = datetime.datetime.now()
        self.strictness_level = strictness_level
        # Any test case that errored makes the whole endpoint ERROR.
        if any(tc.status == ExecutedTestCaseResult.Status.ERROR for tc in self.executed_test_cases):
            self.overall_status = TestResult.Status.ERROR
            first_error = next((tc.message for tc in self.executed_test_cases
                                if tc.status == ExecutedTestCaseResult.Status.ERROR),
                               "Unknown test case error")
            self.error_message = f"测试用例执行错误: {first_error}"
            return
        # No test cases were executed at all.
        if not self.executed_test_cases:
            if self.overall_status == TestResult.Status.SKIPPED:
                return  # keep the explicit SKIPPED status
            self.overall_status = TestResult.Status.ERROR
            self.error_message = "没有为该端点找到或执行任何适用的测试用例。"
            return
        failed_tcs = [tc for tc in self.executed_test_cases
                      if tc.status == ExecutedTestCaseResult.Status.FAILED]
        if not failed_tcs:
            self.overall_status = TestResult.Status.PASSED
            return
        # Fix: this line was duplicated (logged once here and once again inside
        # the strictness branch); keep a single log statement.
        logging.info(f"strictness_level: {strictness_level}")
        if self.strictness_level:
            # TestSeverity is an ordered enum, so its members compare directly.
            relevant_failed_tcs = [
                tc for tc in failed_tcs
                if tc.test_case_severity >= strictness_level
            ]
            if not relevant_failed_tcs:
                self.overall_status = TestResult.Status.PASSED
                self.message = f"通过(严格等级: {strictness_level.name})。注意:存在 {len(failed_tcs)} 个较低严重性的失败用例。"
            else:
                self.overall_status = TestResult.Status.FAILED
                self.message = f"失败(严格等级: {strictness_level.name})。"
                logging.info(f"relevant_failed_tcs: {relevant_failed_tcs}")
        else:
            # Default behaviour: high-severity failures fail the endpoint,
            # lower-severity ones degrade it to PARTIAL_SUCCESS.
            if any(tc.test_case_severity in [TestSeverity.CRITICAL, TestSeverity.HIGH] for tc in failed_tcs):
                self.overall_status = TestResult.Status.FAILED
            else:
                self.overall_status = TestResult.Status.PARTIAL_SUCCESS

    @property
    def duration(self) -> float:
        """Seconds between start and end; 0.0 until finalize_endpoint_test()."""
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds()
        return 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict for reporting."""
        data = {
            "endpoint_id": self.endpoint_id,
            "endpoint_name": self.endpoint_name,
            "overall_status": self.overall_status.value,
            "duration_seconds": self.duration,
            "start_time": self.start_time.isoformat() if self.start_time else None,
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "executed_test_cases": [tc.to_dict() for tc in self.executed_test_cases]
        }
        if self.error_message:
            data["error_message"] = self.error_message
        # Fix: self.message (e.g. the strictness verdict set in
        # finalize_endpoint_test) was previously never serialized, so the
        # generated report silently lost it. Adding a key is backward-compatible.
        if self.message:
            data["message"] = self.message
        return data
class TestSummary:
    """Aggregated summary of a test run (updated for the new result structure).

    Collects endpoint results (TestResult), stage results (ExecutedStageResult)
    and raw API call details, and derives the counters and success-rate
    properties used by the JSON report and the console summary.
    """
    def __init__(self):
        # Endpoint-level counters.
        self.total_endpoints_defined: int = 0
        self.total_endpoints_tested: int = 0
        self.endpoints_passed: int = 0
        self.endpoints_failed: int = 0
        self.endpoints_partial_success: int = 0
        self.endpoints_error: int = 0
        self.endpoints_skipped: int = 0
        # Test-case-level counters.
        self.total_test_cases_applicable: int = 0
        self.total_test_cases_executed: int = 0
        self.test_cases_passed: int = 0
        self.test_cases_failed: int = 0
        self.test_cases_error: int = 0
        self.test_cases_skipped_in_endpoint: int = 0
        # Stage-level counters.
        self.total_stages_defined: int = 0
        self.total_stages_executed: int = 0
        self.stages_passed: int = 0
        self.stages_failed: int = 0
        self.stages_error: int = 0
        self.stages_skipped: int = 0
        self.start_time = datetime.datetime.now()
        self.end_time: Optional[datetime.datetime] = None
        # Raw result objects backing the counters above.
        self.detailed_results: List[TestResult] = []
        self.stage_results: List[ExecutedStageResult] = []
        self.api_call_details_summary: List[Dict[str, Any]] = []
        self.errors: List[str] = []
        self.logger = logging.getLogger(__name__)
        self.logger.info("TestSummary initialized.")
    def add_endpoint_result(self, result: TestResult):  # result is the new TestResult type
        """Record one endpoint result and update endpoint/test-case counters."""
        self.detailed_results.append(result)
        # Only endpoints where testing was actually attempted count as "tested".
        if result.executed_test_cases or result.overall_status not in [TestResult.Status.SKIPPED, TestResult.Status.ERROR]:
            # Exclude the synthetic ERROR produced when no applicable test case was found.
            if not (len(result.executed_test_cases) == 0 and result.overall_status == TestResult.Status.ERROR and result.error_message and "没有为该端点找到或执行任何适用的测试用例" in result.error_message):
                self.total_endpoints_tested +=1
        if result.overall_status == TestResult.Status.PASSED:
            self.endpoints_passed += 1
        elif result.overall_status == TestResult.Status.FAILED:
            self.endpoints_failed += 1
        elif result.overall_status == TestResult.Status.PARTIAL_SUCCESS:
            self.endpoints_partial_success +=1
        elif result.overall_status == TestResult.Status.ERROR:
            self.endpoints_error += 1
        elif result.overall_status == TestResult.Status.SKIPPED:  # endpoint-level skip
            self.endpoints_skipped +=1
        for tc_result in result.executed_test_cases:
            self.total_test_cases_executed += 1  # each APITestCase counts as one execution
            if tc_result.status == ExecutedTestCaseResult.Status.PASSED:
                self.test_cases_passed += 1
            elif tc_result.status == ExecutedTestCaseResult.Status.FAILED:
                self.test_cases_failed += 1
            elif tc_result.status == ExecutedTestCaseResult.Status.ERROR:
                self.test_cases_error +=1
            elif tc_result.status == ExecutedTestCaseResult.Status.SKIPPED:
                self.test_cases_skipped_in_endpoint +=1
    def set_total_endpoints_defined(self, count: int):
        """Record how many endpoints the parsed spec defines."""
        self.total_endpoints_defined = count
    def set_total_test_cases_applicable(self, count: int):
        """Record how many test cases were applicable across all endpoints."""
        self.total_test_cases_applicable = count
    # Stage statistics (formerly "scenario" statistics).
    def add_stage_result(self, result: ExecutedStageResult):
        """Record one stage result and update the stage counters."""
        self.stage_results.append(result)
        if result.overall_status == ExecutedStageResult.Status.PASSED:
            self.stages_passed += 1
        elif result.overall_status == ExecutedStageResult.Status.FAILED:
            self.stages_failed += 1
        elif result.overall_status == ExecutedStageResult.Status.ERROR:  # Assuming ERROR maps to a fail count for summary
            self.stages_error +=1  # or self.stages_failed +=1 depending on how you want to count errors
        elif result.overall_status == ExecutedStageResult.Status.SKIPPED:
            self.stages_skipped += 1
        self.total_stages_executed +=1
    def set_total_stages_defined(self, count: int):
        """Record how many stages were defined/discovered."""
        self.total_stages_defined = count
    def finalize_summary(self):
        """Stamp the end time; duration is 0.0 until this is called."""
        self.end_time = datetime.datetime.now()
    @property
    def duration(self) -> float:
        """Total run time in seconds; 0.0 until finalize_summary() is called."""
        if not self.end_time:
            return 0.0
        return (self.end_time - self.start_time).total_seconds()
    @property
    def endpoint_success_rate(self) -> float:
        """Percentage of tested endpoints whose overall status is PASSED."""
        if self.total_endpoints_tested == 0:
            return 0.0
        # Only PASSED counts as success (PARTIAL_SUCCESS does not).
        return (self.endpoints_passed / self.total_endpoints_tested) * 100
    @property
    def test_case_success_rate(self) -> float:
        """Percentage of executed test cases that passed."""
        if self.total_test_cases_executed == 0:
            return 0.0
        return (self.test_cases_passed / self.total_test_cases_executed) * 100
    def to_dict(self) -> Dict[str, Any]:
        """Serialize counters and nested results for the JSON report.

        NOTE(review): "endpoints_tested" below is len(detailed_results), while
        endpoint_success_rate divides by total_endpoints_tested, which can be
        smaller — confirm whether the report should use the same base.
        """
        data = {
            "start_time": self.start_time.isoformat(),
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "duration_seconds": f"{self.duration:.2f}",
            "overall_summary": {
                "total_endpoints_defined": self.total_endpoints_defined,
                "endpoints_tested": len(self.detailed_results),
                "endpoints_passed": self.endpoints_passed,
                "endpoints_failed": self.endpoints_failed,
                "endpoints_error": self.endpoints_error,
                "endpoints_skipped": self.endpoints_skipped,
                "endpoints_partial_success": self.endpoints_partial_success,
                "endpoint_success_rate": f"{self.endpoint_success_rate:.2f}%",
                "total_test_cases_applicable": self.total_test_cases_applicable,
                "total_test_cases_executed": self.total_test_cases_executed,
                "test_cases_passed": self.test_cases_passed,
                "test_cases_failed": self.test_cases_failed,
                "test_cases_error": self.test_cases_error,
                "test_cases_skipped_in_endpoint": self.test_cases_skipped_in_endpoint,
                "test_case_success_rate": f"{self.test_case_success_rate:.2f}%",
                "total_stages_defined": self.total_stages_defined,
                "total_stages_executed": self.total_stages_executed,
                "stages_passed": self.stages_passed,
                "stages_failed": self.stages_failed,
                "stages_error": self.stages_error,
                "stages_skipped": self.stages_skipped,
                "stage_success_rate": f"{self.stage_success_rate:.2f}%" if self.total_stages_executed > 0 else "N/A",
            },
            "errors": self.errors,
            "endpoint_results": [res.to_dict() for res in self.detailed_results],
            "stage_results": [res.to_dict() for res in self.stage_results],
            "api_call_details_summary": self.api_call_details_summary
        }
        return data
    def to_json(self, pretty=True) -> str:
        """Serialize the summary to JSON (2-space indent when pretty)."""
        indent = 2 if pretty else None
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)
    def print_summary_to_console(self):  # Renamed from print_summary
        """Log a human-readable summary of the whole run via the module logger."""
        # (Implementation can be more detailed based on the new stats)
        self.logger.info("-------------------- API Test Summary --------------------")
        self.logger.info(f"Start Time: {self.start_time.isoformat()}")
        if self.end_time:
            self.logger.info(f"End Time: {self.end_time.isoformat()}")
            self.logger.info(f"Total Duration: {self.duration:.2f} seconds")
        self.logger.info("--- Endpoints ---")
        self.logger.info(f"Total Endpoints Defined: {self.total_endpoints_defined}")
        self.logger.info(f"Endpoints Tested: {len(self.detailed_results)}")
        self.logger.info(f"  Passed: {self.endpoints_passed}")
        self.logger.info(f"  Failed: {self.endpoints_failed}")
        self.logger.info(f"  Error: {self.endpoints_error}")
        self.logger.info(f"  Skipped: {self.endpoints_skipped}")
        self.logger.info(f"  Partial Success: {self.endpoints_partial_success}")
        self.logger.info(f"  Success Rate: {self.endpoint_success_rate:.2f}%")
        self.logger.info("--- Test Cases ---")
        self.logger.info(f"Total Test Cases Applicable: {self.total_test_cases_applicable}")
        self.logger.info(f"Total Test Cases Executed: {self.total_test_cases_executed}")
        self.logger.info(f"  Passed: {self.test_cases_passed}")
        self.logger.info(f"  Failed: {self.test_cases_failed}")
        self.logger.info(f"  Error: {self.test_cases_error}")
        self.logger.info(f"  Skipped (within endpoints): {self.test_cases_skipped_in_endpoint}")
        self.logger.info(f"  Success Rate: {self.test_case_success_rate:.2f}%")
        self.logger.info("--- Stages ---")
        self.logger.info(f"Total Stages Defined: {self.total_stages_defined}")
        self.logger.info(f"Total Stages Executed: {self.total_stages_executed}")
        self.logger.info(f"  Passed: {self.stages_passed}")
        self.logger.info(f"  Failed: {self.stages_failed}")
        self.logger.info(f"  Error: {self.stages_error}")
        self.logger.info(f"  Skipped: {self.stages_skipped}")
        if self.total_stages_executed > 0:
            self.logger.info(f"  Success Rate: {self.stage_success_rate:.2f}%")
        else:
            self.logger.info("  Success Rate: N/A (No stages executed)")
        if self.errors:
            self.logger.error("--- Orchestration Errors ---")
            for err in self.errors:
                self.logger.error(err)
        self.logger.info("--------------------------------------------------------")
    def add_api_call_details(self, details: Dict[str, Any]):
        """Append one API call detail record to the report payload."""
        self.api_call_details_summary.append(details)
    def add_error(self, error_message: str):
        """Record an orchestration-level error (not tied to one endpoint)."""
        self.errors.append(error_message)
    @property
    def stage_success_rate(self) -> float:
        """Percentage of executed stages that passed."""
        if self.total_stages_executed == 0:
            return 0.0
        return (self.stages_passed / self.total_stages_executed) * 100
class APITestOrchestrator:
"""
测试编排器负责协调整个API测试流程。
包括:
1. 解析API定义 (YAPI, Swagger)
2. 加载自定义测试用例 (BaseAPITestCase)
3. 执行测试用例并收集结果
4. 加载和执行API场景 (BaseAPIScenario) - 已实现
5. 加载和执行API测试阶段 (BaseAPIStage) - 新增
6. 生成测试报告和API调用详情
"""
MAX_RECURSION_DEPTH_PYDANTIC = 10 # 新增一个常量用于 Pydantic 模型创建的递归深度限制
def __init__(self, base_url: str,
             custom_test_cases_dir: Optional[str] = None,
             stages_dir: Optional[str] = None,
             llm_api_key: Optional[str] = None,
             llm_base_url: Optional[str] = None,
             llm_model_name: Optional[str] = None,
             use_llm_for_request_body: bool = False,
             use_llm_for_path_params: bool = False,
             use_llm_for_query_params: bool = False,
             use_llm_for_headers: bool = False,
             stage_llm_config: Optional[Dict[str, bool]] = None,
             output_dir: Optional[str] = None,
             strictness_level: Optional[str] = None,
             ignore_ssl: bool = False
             ):
    """
    Initialize the test orchestrator.

    Args:
        base_url (str): Base URL of the API under test (trailing '/' stripped).
        custom_test_cases_dir (Optional[str]): Directory holding custom test cases.
        stages_dir (Optional[str]): Directory holding custom test stages.
        llm_api_key (Optional[str]): API key for the LLM service.
        llm_base_url (Optional[str]): Custom base URL for the LLM service.
        llm_model_name (Optional[str]): Name of the LLM model to use.
        use_llm_for_request_body (bool): Use the LLM to generate request bodies.
        use_llm_for_path_params (bool): Use the LLM to generate path parameters.
        use_llm_for_query_params (bool): Use the LLM to generate query parameters.
        use_llm_for_headers (bool): Use the LLM to generate header parameters.
        stage_llm_config (Optional[Dict[str, bool]]): Stage-specific LLM flags.
        output_dir (Optional[str]): Output directory for reports and artifacts.
        strictness_level (Optional[str]): Test strictness, e.g. 'CRITICAL', 'HIGH'.
        ignore_ssl (bool): Skip SSL certificate verification.
    """
    self.logger = logging.getLogger(__name__)
    self.base_url = base_url.rstrip('/')
    self.api_caller = APICaller()
    self.test_case_registry = TestCaseRegistry(test_cases_dir=custom_test_cases_dir)
    self.global_api_call_details: List[APICallDetail] = []
    self.ignore_ssl = ignore_ssl
    self.stages_dir = stages_dir
    self.stage_registry: Optional[StageRegistry] = None
    self.llm_service: Optional[LLMService] = None
    # LLM configuration for ordinary test cases.
    self.llm_config = {
        "use_for_request_body": use_llm_for_request_body,
        "use_for_path_params": use_llm_for_path_params,
        "use_for_query_params": use_llm_for_query_params,
        "use_for_headers": use_llm_for_headers,
    }
    # Stage-specific LLM configuration (note: different key names than llm_config).
    self.stage_llm_config = stage_llm_config or {
        "use_llm_for_request_body": False,
        "use_llm_for_path_params": False,
        "use_llm_for_query_params": False,
        "use_llm_for_headers": False,
    }
    # NOTE(review): if llm_api_key is provided but llm_base_url is missing,
    # none of the branches below run, so the use_llm_* flags stay enabled
    # while self.llm_service remains None (LLM use then silently no-ops via
    # _should_use_llm_for_param_type) — confirm whether that combination
    # should instead log a warning and disable the flags.
    if llm_api_key and llm_base_url and LLMService:  # both key and base URL required
        try:
            self.llm_service = LLMService(api_key=llm_api_key, base_url=llm_base_url, model_name=llm_model_name)
            self.logger.info(f"LLMService initialized with model: {self.llm_service.model_name}.")
            if not any(self.llm_config.values()):
                self.logger.info("LLMService is initialized, but no LLM generation flags (--use-llm-for-*) are enabled.")
        except Exception as e:
            self.logger.error(f"Failed to initialize LLMService: {e}. LLM features will be disabled.", exc_info=True)
            self.llm_service = None
    elif not LLMService and any(self.llm_config.values()):
        self.logger.warning("LLM usage flags are set, but LLMService components are not available. LLM features will be disabled.")
        for key in self.llm_config: self.llm_config[key] = False
    elif not llm_api_key:
        self.logger.info("LLM API key not provided. LLM-based data generation will be disabled.")
        for key in self.llm_config: self.llm_config[key] = False
    if self.stages_dir:
        self.stage_registry = StageRegistry(stages_dir=self.stages_dir)
        if self.stage_registry and self.stage_registry.get_discovery_errors():
            for err in self.stage_registry.get_discovery_errors():
                self.logger.error(f"Error loading stage: {err}")
        elif self.stage_registry:
            self.logger.info(f"StageRegistry initialized. Loaded {len(self.stage_registry.get_all_stages())} stages.")  # Changed from get_all_stage_classes
    else:
        self.logger.info("No stages_dir provided, stage testing will be skipped.")
    # Reports and artifacts go here; fall back to the CWD if creation fails.
    self.output_dir_path = Path(output_dir) if output_dir else Path("./test_reports_orchestrator")
    try:
        self.output_dir_path.mkdir(parents=True, exist_ok=True)
        self.logger.info(f"Orchestrator output directory set to: {self.output_dir_path.resolve()}")
    except OSError as e:
        self.logger.warning(f"Could not create orchestrator output directory {self.output_dir_path}: {e}. Falling back to current directory.")
        self.output_dir_path = Path(".")
    self.schema_cache: Dict[str, Type[BaseModel]] = {}
    self.model_name_counts: Dict[str, int] = defaultdict(int)
    self.parser = InputParser()
    self.json_resolver_cache: Dict[str, Any] = {}
    self.json_validator = JSONSchemaValidator()
    # Convert the string strictness_level into a TestSeverity enum member.
    self.strictness_level: Optional[TestSeverity] = None
    if strictness_level and hasattr(TestSeverity, strictness_level):
        self.strictness_level = TestSeverity[strictness_level]
        logging.info(f"strictness_level: {self.strictness_level}")
    elif strictness_level:
        logging.warning(f"提供了无效的严格等级 '{strictness_level}'。将使用默认行为。有效值: {', '.join([e.name for e in TestSeverity])}")
    # Initialized here as None so they cannot be used by mistake before
    # _execute_tests_from_parsed_spec sets them up.
    self.json_schema_validator: Optional[JSONSchemaValidator] = None
    self.schema_provider: Optional[SchemaProvider] = None
def get_api_call_details(self) -> List[APICallDetail]:
    """Return every API call detail record collected during this run."""
    collected = self.global_api_call_details
    return collected
def _should_use_llm_for_param_type(
        self,
        param_type_key: str,  # one of "path_params", "query_params", "headers", "body"
        test_case_instance: Optional[BaseAPITestCase]
) -> bool:
    """
    Decide whether the LLM should be used for one kind of parameter.
    Combines the orchestrator-wide flags with any per-test-case override.
    """
    # Without an LLM service there is nothing to decide.
    if not self.llm_service:
        return False
    # Map each parameter kind to (global llm_config key, test-case attribute).
    flag_lookup = {
        "body": ("use_for_request_body", "use_llm_for_body"),
        "path_params": ("use_for_path_params", "use_llm_for_path_params"),
        "query_params": ("use_for_query_params", "use_llm_for_query_params"),
        "headers": ("use_for_headers", "use_llm_for_headers"),
    }
    entry = flag_lookup.get(param_type_key)
    if entry is None:
        self.logger.warning(f"未知的参数类型键 '{param_type_key}' 在 _should_use_llm_for_param_type 中检查。")
        return False
    config_key, tc_attr_name = entry
    global_flag = self.llm_config[config_key]
    tc_specific_flag: Optional[bool] = None
    if test_case_instance:
        tc_specific_flag = getattr(test_case_instance, tc_attr_name)
    # Resolution order:
    # 1. An explicit per-test-case setting (not None) wins.
    # 2. Otherwise fall back to the global configuration.
    return tc_specific_flag if tc_specific_flag is not None else global_flag
def _create_pydantic_model_from_schema(
        self,
        schema: Dict[str, Any],
        model_name: str,
        recursion_depth: int = 0
) -> Optional[Type[BaseModel]]:
    """
    Dynamically creates a Pydantic model from a JSON schema.
    Handles nested schemas, arrays, and various OpenAPI/JSON Schema constructs.
    Uses a cache (_dynamic_model_cache) to avoid redefining identical models.

    Args:
        schema: JSON-schema dict; must be of type 'object' to yield a model.
        model_name: Desired model name. It doubles as the cache key, so the
            caller must make it unique per distinct schema structure.
        recursion_depth: Current nesting depth, bounded by
            MAX_RECURSION_DEPTH_PYDANTIC to break circular references.

    Returns:
        The created (or cached) Pydantic model class, or None on failure.
    """
    # Cache-key strategy: the model_name alone is used as the cache key.
    # This relies on callers (e.g. _generate_data_from_schema,
    # _build_object_schema_for_params) constructing a unique name per distinct
    # schema structure; if one name could map to different schemas within a
    # run, a schema-derived key (e.g. serialized schema) would be needed.
    if model_name in _dynamic_model_cache:
        self.logger.debug(f"Reusing cached Pydantic model: {model_name}")
        return _dynamic_model_cache[model_name]
    if recursion_depth > self.MAX_RECURSION_DEPTH_PYDANTIC:
        self.logger.error(f"创建Pydantic模型 '{model_name}' 时达到最大递归深度 {self.MAX_RECURSION_DEPTH_PYDANTIC}。可能存在循环引用。")
        return None
    # Sanitize the model name into a valid Python identifier.
    safe_model_name = "".join(c if c.isalnum() or c == '_' else '_' for c in model_name)
    # Precedence note: the condition parses as
    # (not name) or ((not first-char-alpha) and (first-char != '_')),
    # i.e. prefix when the name is empty or starts with an invalid character.
    if not safe_model_name or not safe_model_name[0].isalpha() and safe_model_name[0] != '_':
        safe_model_name = f"DynamicModel_{safe_model_name}"
    # Second cache check, this time under the sanitized name.
    if safe_model_name in _dynamic_model_cache:
        self.logger.debug(f"从缓存返回动态模型: {safe_model_name}")
        return _dynamic_model_cache[safe_model_name]
    self.logger.debug(f"开始从Schema创建Pydantic模型: '{safe_model_name}' (原始名: '{model_name}', 深度: {recursion_depth})")
    if not isinstance(schema, dict) or schema.get('type') != 'object':
        # Safely get type for logging if schema is not a dict or does not have 'type'
        schema_type_for_log = schema.get('type') if isinstance(schema, dict) else type(schema).__name__
        self.logger.error(f"提供的Schema用于模型 '{safe_model_name}' 的必须是 type 'object' 且是一个字典, 实际: {schema_type_for_log}")
        return None
    properties = schema.get('properties', {})
    required_fields = set(schema.get('required', []))
    field_definitions: Dict[str, Tuple[Any, Any]] = {}
    for prop_name, prop_schema in properties.items():
        if not isinstance(prop_schema, dict):
            # Property schemas that are not dicts are skipped with a warning.
            self.logger.warning(f"属性 '{prop_name}' 在模型 '{safe_model_name}' 中的Schema无效已跳过。")
            continue
        python_type: Any = Any
        field_args: Dict[str, Any] = {}
        default_value: Any = ...  # Ellipsis marks a required field with no default
        if 'default' in prop_schema:
            default_value = prop_schema['default']
        elif prop_name not in required_fields:
            # Optional field without an explicit default -> defaults to None.
            default_value = None
        if 'description' in prop_schema:
            field_args['description'] = prop_schema['description']
        json_type = prop_schema.get('type')
        json_format = prop_schema.get('format')
        if json_type == 'object':
            # Nested object: recurse to build a child model.
            nested_model_name_base = f"{safe_model_name}_{prop_name}"
            python_type = self._create_pydantic_model_from_schema(prop_schema, nested_model_name_base, recursion_depth + 1)
            if python_type is None:
                self.logger.warning(f"无法为 '{safe_model_name}' 中的嵌套属性 '{prop_name}' 创建模型,已跳过。")
                continue
        elif json_type == 'array':
            items_schema = prop_schema.get('items')
            if not isinstance(items_schema, dict):
                self.logger.warning(f"数组属性 '{prop_name}' 在模型 '{safe_model_name}' 中的 'items' schema无效已跳过。")
                continue
            item_type: Any = Any
            item_json_type = items_schema.get('type')
            item_json_format = items_schema.get('format')
            if item_json_type == 'object':
                # Array of objects: recurse for the item model.
                item_model_name_base = f"{safe_model_name}_{prop_name}_Item"
                item_type = self._create_pydantic_model_from_schema(items_schema, item_model_name_base, recursion_depth + 1)
                if item_type is None:
                    self.logger.warning(f"无法为 '{safe_model_name}' 中的数组属性 '{prop_name}' 的项创建模型,已跳过。")
                    continue
            elif item_json_type == 'string':
                # Map string formats onto richer Python/Pydantic types.
                if item_json_format == 'date-time': item_type = dt.datetime
                elif item_json_format == 'date': item_type = dt.date
                elif item_json_format == 'email': item_type = EmailStr
                elif item_json_format == 'uuid': item_type = UUID
                else: item_type = str
            elif item_json_type == 'integer': item_type = int
            elif item_json_type == 'number': item_type = float
            elif item_json_type == 'boolean': item_type = bool
            else:
                self.logger.warning(f"数组 '{prop_name}' 中的项具有未知类型 '{item_json_type}',默认为 Any。")
            python_type = List[item_type]  # type: ignore
        elif json_type == 'string':
            if json_format == 'date-time': python_type = dt.datetime
            elif json_format == 'date': python_type = dt.date
            elif json_format == 'email': python_type = EmailStr
            elif json_format == 'uuid': python_type = UUID
            else: python_type = str
            # Carry string constraints over to Pydantic Field arguments.
            if 'minLength' in prop_schema: field_args['min_length'] = prop_schema['minLength']
            if 'maxLength' in prop_schema: field_args['max_length'] = prop_schema['maxLength']
            if 'pattern' in prop_schema: field_args['pattern'] = prop_schema['pattern']
        elif json_type == 'integer':
            python_type = int
            if 'minimum' in prop_schema: field_args['ge'] = prop_schema['minimum']
            if 'maximum' in prop_schema: field_args['le'] = prop_schema['maximum']
        elif json_type == 'number':
            python_type = float
            if 'minimum' in prop_schema: field_args['ge'] = prop_schema['minimum']
            if 'maximum' in prop_schema: field_args['le'] = prop_schema['maximum']
        elif json_type == 'boolean':
            python_type = bool
        elif json_type is None and '$ref' in prop_schema:
            # $ref resolution is not implemented; fall back to Any.
            self.logger.warning(f"Schema $ref '{prop_schema['$ref']}' in '{safe_model_name}.{prop_name}' not yet supported. Defaulting to Any.")
            python_type = Any
        else:
            self.logger.warning(f"属性 '{prop_name}' 在模型 '{safe_model_name}' 中具有未知类型 '{json_type}',默认为 Any。")
            python_type = Any
        if 'enum' in prop_schema:
            # Enum values are only appended to the description, not enforced as a type.
            enum_values = prop_schema['enum']
            if enum_values:
                enum_desc = f" (Enum values: {', '.join(map(str, enum_values))})"
                field_args['description'] = field_args.get('description', '') + enum_desc
        current_field_is_optional = prop_name not in required_fields
        if current_field_is_optional and python_type is not Any and default_value is None:
            # For Pydantic v1/v2, if a field is not required and has no other default, it's Optional.
            # The python_type might already be Optional (e.g. from a nested optional model);
            # only wrap when it is not already a Union containing NoneType.
            if not (hasattr(python_type, '__origin__') and python_type.__origin__ is Union and type(None) in python_type.__args__):
                python_type = Optional[python_type]
        field_definitions[prop_name] = (python_type, Field(default_value, **field_args))
    if not field_definitions:
        self.logger.warning(f"模型 '{safe_model_name}' 没有有效的字段定义,无法创建。")
        # An object schema with no usable properties (e.g. {}) still yields a
        # bare BaseModel subclass so callers receive a model rather than None.
        try:
            EmptyModel = create_model(safe_model_name, __base__=BaseModel)
            _dynamic_model_cache[safe_model_name] = EmptyModel
            self.logger.info(f"创建了一个空的动态Pydantic模型: '{safe_model_name}' (由于无属性定义)")
            return EmptyModel
        except Exception as e_empty:
            self.logger.error(f"尝试为 '{safe_model_name}' 创建空模型时失败: {e_empty}", exc_info=True)
            return None
    try:
        # ForwardRef for self-referencing models is complex; not fully handled here yet.
        # If a type in field_definitions is a string (e.g., a ForwardRef string), create_model handles it.
        DynamicModel = create_model(safe_model_name, **field_definitions, __base__=BaseModel)  # type: ignore
        _dynamic_model_cache[safe_model_name] = DynamicModel
        self.logger.info(f"成功创建/缓存了动态Pydantic模型: '{safe_model_name}'")
        # update_forward_refs handling is intentionally disabled for now;
        # nested creation order is assumed to resolve most dependencies.
        # if hasattr(DynamicModel, 'update_forward_refs'):
        #     try:
        #         DynamicModel.update_forward_refs(**_dynamic_model_cache)
        #         self.logger.debug(f"Attempted to update forward refs for {safe_model_name}")
        #     except Exception as e_fwd:
        #         self.logger.warning(f"Error updating forward_refs for {safe_model_name}: {e_fwd}")
        return DynamicModel
    except Exception as e:
        self.logger.error(f"使用Pydantic create_model创建 '{safe_model_name}' 时失败: {e}", exc_info=True)
        return None
    def _execute_single_test_case(
        self,
        test_case_class: Type[BaseAPITestCase],
        endpoint_spec: Union[YAPIEndpoint, SwaggerEndpoint, DMSEndpoint], # spec of the current endpoint
        global_api_spec: Union[ParsedYAPISpec, ParsedSwaggerSpec, ParsedDMSSpec] # spec of the entire API
    ) -> ExecutedTestCaseResult:
        """
        Execute a single test case against one endpoint.

        Flow:
        1. Prepare request data (path params, query params, headers, body).
           - First try the test case's generate_xxx methods.
           - If the test case does not override them or returns None, generate
             default data from the API spec.
           - If LLM support is enabled and the test case allows it, use the LLM
             to generate the data.
        2. (If applicable) call the test case's modify_request_url hook.
        3. (If applicable) call the test case's validate_request_url,
           validate_request_headers and validate_request_body hooks.
        4. Send the API request.
        5. Record the response.
        6. Call the test case's validate_response and check_performance hooks.
        7. Aggregate validation results and determine the test case status.
        """
        start_time = time.monotonic()
        validation_results: List[ValidationResult] = []
        overall_status: ExecutedTestCaseResult.Status  # NOTE(review): annotated but never assigned or read in this method
        execution_message = ""  # NOTE(review): currently unused in this method
        test_case_instance: Optional[BaseAPITestCase] = None # Initialize to None
        # Convert endpoint_spec into a dict if it is not one already.
        endpoint_spec_dict: Dict[str, Any]
        if isinstance(endpoint_spec, dict):
            endpoint_spec_dict = endpoint_spec
            # self.logger.debug(f"endpoint_spec 已经是字典类型。")
        elif hasattr(endpoint_spec, 'to_dict') and callable(endpoint_spec.to_dict):
            try:
                endpoint_spec_dict = endpoint_spec.to_dict()
                # self.logger.debug(f"成功通过 to_dict() 方法将类型为 {type(endpoint_spec)} 的 endpoint_spec 转换为字典。")
                if not endpoint_spec_dict: # to_dict() returned an empty dict
                    # self.logger.warning(f"endpoint_spec.to_dict() (类型: {type(endpoint_spec)}) 返回了一个空字典。")
                    # Try fallback conversion: rebuild the dict manually from known attributes.
                    if isinstance(endpoint_spec, (YAPIEndpoint, SwaggerEndpoint, DMSEndpoint)):
                        # self.logger.debug(f"尝试从 {type(endpoint_spec).__name__} 对象的属性手动构建 endpoint_spec_dict。")
                        endpoint_spec_dict = {
                            "method": getattr(endpoint_spec, 'method', 'UNKNOWN_METHOD').upper(),
                            "path": getattr(endpoint_spec, 'path', 'UNKNOWN_PATH'),
                            "title": getattr(endpoint_spec, 'title', getattr(endpoint_spec, 'summary', '')),
                            "summary": getattr(endpoint_spec, 'summary', ''),
                            "description": getattr(endpoint_spec, 'description', ''),
                            "operationId": getattr(endpoint_spec, 'operation_id', f"{getattr(endpoint_spec, 'method', '').upper()}_{getattr(endpoint_spec, 'path', '').replace('/', '_')}"),
                            "parameters": getattr(endpoint_spec, 'parameters', []) if hasattr(endpoint_spec, 'parameters') else (getattr(endpoint_spec, 'req_query', []) + getattr(endpoint_spec, 'req_headers', [])),
                            "requestBody": getattr(endpoint_spec, 'request_body', None) if hasattr(endpoint_spec, 'request_body') else getattr(endpoint_spec, 'req_body_other', None),
                            "_original_object_type": type(endpoint_spec).__name__
                        }
                        if not any(endpoint_spec_dict.values()): # still essentially empty after manual build
                            # self.logger.error(f"手动从属性构建 endpoint_spec_dict (类型: {type(endpoint_spec)}) 后仍然为空或无效。")
                            endpoint_spec_dict = {} # reset to empty to trigger the error handling below
            except Exception as e:
                self.logger.error(f"调用 endpoint_spec (类型: {type(endpoint_spec)}) 的 to_dict() 方法时出错: {e}。尝试备用转换。")
                if isinstance(endpoint_spec, (YAPIEndpoint, SwaggerEndpoint, DMSEndpoint)):
                    self.logger.debug(f"尝试从 {type(endpoint_spec).__name__} 对象的属性手动构建 endpoint_spec_dict。")
                    endpoint_spec_dict = {
                        "method": getattr(endpoint_spec, 'method', 'UNKNOWN_METHOD').upper(),
                        "path": getattr(endpoint_spec, 'path', 'UNKNOWN_PATH'),
                        "title": getattr(endpoint_spec, 'title', getattr(endpoint_spec, 'summary', '')),
                        "summary": getattr(endpoint_spec, 'summary', ''),
                        "description": getattr(endpoint_spec, 'description', ''),
                        "operationId": getattr(endpoint_spec, 'operation_id', f"{getattr(endpoint_spec, 'method', '').upper()}_{getattr(endpoint_spec, 'path', '').replace('/', '_')}"),
                        "parameters": getattr(endpoint_spec, 'parameters', []) if hasattr(endpoint_spec, 'parameters') else (getattr(endpoint_spec, 'req_query', []) + getattr(endpoint_spec, 'req_headers', [])),
                        "requestBody": getattr(endpoint_spec, 'request_body', None) if hasattr(endpoint_spec, 'request_body') else getattr(endpoint_spec, 'req_body_other', None),
                        "_original_object_type": type(endpoint_spec).__name__
                    }
                    if not any(endpoint_spec_dict.values()): # still essentially empty after manual build
                        self.logger.error(f"手动从属性构建 endpoint_spec_dict (类型: {type(endpoint_spec)}) 后仍然为空或无效。")
                        endpoint_spec_dict = {} # reset to empty to trigger the error handling below
                else:
                    endpoint_spec_dict = {} # conversion failed
        elif hasattr(endpoint_spec, 'data') and isinstance(getattr(endpoint_spec, 'data'), dict): # compatibility with the YAPIEndpoint structure
            endpoint_spec_dict = getattr(endpoint_spec, 'data')
            # self.logger.debug(f"使用了类型为 {type(endpoint_spec)} 的 endpoint_spec 的 .data 属性。")
        else: # no to_dict and no known .data attribute: last-resort generic conversion or manual build
            if isinstance(endpoint_spec, (YAPIEndpoint, SwaggerEndpoint, DMSEndpoint)):
                # self.logger.debug(f"类型为 {type(endpoint_spec).__name__} 的 endpoint_spec 没有 to_dict() 或 data尝试从属性手动构建。")
                endpoint_spec_dict = {
                    "method": getattr(endpoint_spec, 'method', 'UNKNOWN_METHOD').upper(),
                    "path": getattr(endpoint_spec, 'path', 'UNKNOWN_PATH'),
                    "title": getattr(endpoint_spec, 'title', getattr(endpoint_spec, 'summary', '')),
                    "summary": getattr(endpoint_spec, 'summary', ''),
                    "description": getattr(endpoint_spec, 'description', ''),
                    "operationId": getattr(endpoint_spec, 'operation_id', f"{getattr(endpoint_spec, 'method', '').upper()}_{getattr(endpoint_spec, 'path', '').replace('/', '_')}"),
                    "parameters": getattr(endpoint_spec, 'parameters', []) if hasattr(endpoint_spec, 'parameters') else (getattr(endpoint_spec, 'req_query', []) + getattr(endpoint_spec, 'req_headers', [])),
                    "requestBody": getattr(endpoint_spec, 'request_body', None) if hasattr(endpoint_spec, 'request_body') else getattr(endpoint_spec, 'req_body_other', None),
                    "_original_object_type": type(endpoint_spec).__name__
                }
                if not any(endpoint_spec_dict.values()): # still essentially empty after manual build
                    self.logger.error(f"手动从属性构建 endpoint_spec_dict (类型: {type(endpoint_spec)}) 后仍然为空或无效。")
                    endpoint_spec_dict = {} # reset to empty to trigger the error handling below
            else:
                try:
                    endpoint_spec_dict = dict(endpoint_spec)
                    self.logger.warning(f"直接将类型为 {type(endpoint_spec)} 的 endpoint_spec 转换为字典。这可能是一个浅拷贝,并且可能不完整。")
                except TypeError:
                    self.logger.error(f"无法将 endpoint_spec (类型: {type(endpoint_spec)}) 转换为字典,也未找到有效的转换方法。")
                    endpoint_spec_dict = {}
        if not endpoint_spec_dict or not endpoint_spec_dict.get("path") or endpoint_spec_dict.get("path") == 'UNKNOWN_PATH': # still empty or invalid after conversion
            self.logger.error(f"Endpoint spec (原始类型: {type(endpoint_spec)}) 无法有效转换为包含有效路径的字典,测试用例执行可能受影响。最终 endpoint_spec_dict: {endpoint_spec_dict}")
            # Build a minimal endpoint_spec_dict so the test case can still be
            # instantiated, although it will lack most information.
            endpoint_spec_dict = {
                'method': endpoint_spec_dict.get('method', 'UNKNOWN_METHOD'), # keep any method that was already parsed
                'path': 'UNKNOWN_PATH_CONVERSION_FAILED',
                'title': f"Unknown endpoint due to spec conversion error for original type {type(endpoint_spec)}",
                'parameters': [], # ensure empty parameters and requestBody keys exist
                'requestBody': None
            }
        # Ensure global_api_spec (expected to be a ParsedSwaggerSpec or ParsedYAPISpec
        # instance) is converted into a plain dict, needed below for $ref resolution.
        global_spec_dict: Dict[str, Any] = {}
        converted_by_method: Optional[str] = None
        if hasattr(global_api_spec, 'spec') and isinstance(getattr(global_api_spec, 'spec', None), dict) and getattr(global_api_spec, 'spec', None):
            global_spec_dict = global_api_spec.spec # type: ignore
            converted_by_method = ".spec attribute"
        elif is_dataclass(global_api_spec) and not isinstance(global_api_spec, type): # Ensure it's an instance, not the class itself
            try:
                candidate_spec = dataclass_asdict(global_api_spec)
                if isinstance(candidate_spec, dict) and candidate_spec:
                    global_spec_dict = candidate_spec
                    converted_by_method = "dataclasses.asdict()"
            except Exception as e:
                self.logger.debug(f"Calling dataclasses.asdict() on {type(global_api_spec)} failed: {e}, trying other methods.")
        if not global_spec_dict and hasattr(global_api_spec, 'model_dump') and callable(global_api_spec.model_dump):
            try:
                candidate_spec = global_api_spec.model_dump()
                if isinstance(candidate_spec, dict) and candidate_spec:
                    global_spec_dict = candidate_spec
                    converted_by_method = ".model_dump()"
            except Exception as e:
                self.logger.debug(f"Calling .model_dump() on {type(global_api_spec)} failed: {e}, trying other methods.")
        if not global_spec_dict and hasattr(global_api_spec, 'dict') and callable(global_api_spec.dict):
            try:
                candidate_spec = global_api_spec.dict()
                if isinstance(candidate_spec, dict) and candidate_spec:
                    global_spec_dict = candidate_spec
                    converted_by_method = ".dict()"
            except Exception as e:
                self.logger.debug(f"Calling .dict() on {type(global_api_spec)} failed: {e}, trying other methods.")
        if not global_spec_dict and hasattr(global_api_spec, 'to_dict') and callable(global_api_spec.to_dict):
            try:
                candidate_spec = global_api_spec.to_dict()
                if isinstance(candidate_spec, dict) and candidate_spec:
                    global_spec_dict = candidate_spec
                    converted_by_method = ".to_dict()"
            except Exception as e:
                self.logger.debug(f"Calling .to_dict() on {type(global_api_spec)} failed: {e}, trying other methods.")
        if not global_spec_dict and isinstance(global_api_spec, dict) and global_api_spec:
            global_spec_dict = global_api_spec
            converted_by_method = "direct dict"
            self.logger.warning(f"global_api_spec was already a dictionary. This might be unexpected if an object was anticipated.")
        if global_spec_dict and converted_by_method:
            self.logger.debug(f"Successfully converted/retrieved global_api_spec (type: {type(global_api_spec)}) to dict using {converted_by_method}.")
        elif not global_spec_dict :
            self.logger.error(
                f"Failed to convert global_api_spec (type: {type(global_api_spec)}) to a non-empty dictionary using .spec, dataclasses.asdict(), .model_dump(), .dict(), or .to_dict(). "
                f"It's also not a non-empty dictionary itself. JSON reference resolution will be severely limited or fail. Using empty global_spec_dict."
            )
            global_spec_dict = {}
        # --- BEGIN $ref RESOLUTION ---
        if global_spec_dict: # Only attempt resolution if we have the full spec for lookups
            self.logger.debug(f"global_spec_dict keys for $ref resolution: {list(global_spec_dict.keys())}") # log available top-level keys for debugging
            self.logger.debug(f"开始为 endpoint_spec_dict (来自 {type(endpoint_spec)}) 中的 schemas 进行 $ref 解析...")
            # 1. Resolve the requestBody schema (resolution mutates endpoint_spec_dict in place).
            if 'requestBody' in endpoint_spec_dict and isinstance(endpoint_spec_dict['requestBody'], dict):
                if 'content' in endpoint_spec_dict['requestBody'] and isinstance(endpoint_spec_dict['requestBody']['content'], dict):
                    for media_type, media_type_obj in endpoint_spec_dict['requestBody']['content'].items():
                        if isinstance(media_type_obj, dict) and 'schema' in media_type_obj:
                            self.logger.debug(f"正在解析 requestBody content '{media_type}' 的 schema...")
                            original_schema = media_type_obj['schema']
                            media_type_obj['schema'] = schema_utils.resolve_json_schema_references(original_schema, global_spec_dict)
                            # self.logger.debug(f"解析后的 requestBody content '{media_type}' schema: {json.dumps(media_type_obj['schema'], indent=2)}")
            # 2. Resolve parameter schemas (OpenAPI 2.0 'in: body' parameter or OpenAPI 3.0 parameters)
            if 'parameters' in endpoint_spec_dict and isinstance(endpoint_spec_dict['parameters'], list):
                for i, param in enumerate(endpoint_spec_dict['parameters']):
                    if isinstance(param, dict) and 'schema' in param:
                        self.logger.debug(f"正在解析 parameters[{i}] ('{param.get('name', 'N/A')}') 的 schema...")
                        original_param_schema = param['schema']
                        param['schema'] = schema_utils.resolve_json_schema_references(original_param_schema, global_spec_dict)
                        # self.logger.debug(f"解析后的 parameters[{i}] schema: {json.dumps(param['schema'], indent=2)}")
            # 3. Resolve response schemas
            if 'responses' in endpoint_spec_dict and isinstance(endpoint_spec_dict['responses'], dict):
                for status_code, response_obj in endpoint_spec_dict['responses'].items():
                    if isinstance(response_obj, dict) and 'content' in response_obj and isinstance(response_obj['content'], dict):
                        for media_type, media_type_obj in response_obj['content'].items():
                            if isinstance(media_type_obj, dict) and 'schema' in media_type_obj:
                                self.logger.debug(f"正在解析 responses '{status_code}' content '{media_type}' 的 schema...")
                                original_resp_schema = media_type_obj['schema']
                                media_type_obj['schema'] = schema_utils.resolve_json_schema_references(original_resp_schema, global_spec_dict)
                                # self.logger.debug(f"解析后的 response '{status_code}' content '{media_type}' schema: {json.dumps(media_type_obj['schema'], indent=2)}")
                    # OpenAPI 2.0 response schema directly under response object
                    elif isinstance(response_obj, dict) and 'schema' in response_obj:
                        self.logger.debug(f"正在解析 responses '{status_code}' 的 schema (OpenAPI 2.0 style)...")
                        original_resp_schema = response_obj['schema']
                        response_obj['schema'] = schema_utils.resolve_json_schema_references(original_resp_schema, global_spec_dict)
            self.logger.info(f"Endpoint spec (来自 {type(endpoint_spec)}) 中的 schemas $ref 解析完成。")
        else:
            self.logger.warning(f"global_spec_dict 为空,跳过 endpoint_spec_dict (来自 {type(endpoint_spec)}) 的 $ref 解析。")
        # --- END $ref RESOLUTION ---
        # Inject global_spec_dict into endpoint_spec_dict for possible internal
        # resolution (in case to_dict() did not already include it).
        if '_global_api_spec_for_resolution' not in endpoint_spec_dict and global_spec_dict:
            endpoint_spec_dict['_global_api_spec_for_resolution'] = global_spec_dict
        try:
            self.logger.debug(f"准备实例化测试用例类: {test_case_class.__name__} 使用 endpoint_spec (keys: {list(endpoint_spec_dict.keys()) if endpoint_spec_dict else 'None'}) 和 global_api_spec (keys: {list(global_spec_dict.keys()) if global_spec_dict else 'None'})")
            test_case_instance = test_case_class(
                endpoint_spec=endpoint_spec_dict,
                global_api_spec=global_spec_dict,
                json_schema_validator=self.json_validator,
                llm_service=self.llm_service
            )
            self.logger.info(f"开始执行测试用例 '{test_case_instance.id}' ({test_case_instance.name}) for endpoint '{endpoint_spec_dict.get('method', 'N/A')} {endpoint_spec_dict.get('path', 'N/A')}'")
            # Pass test_case_instance into _prepare_initial_request_data and
            # read the prepared values off the returned context object.
            request_context_data = self._prepare_initial_request_data(endpoint_spec_dict, test_case_instance=test_case_instance)
            # Pull the individual parts out of the request context object.
            method = request_context_data.method
            path_params_data = request_context_data.path_params
            query_params_data = request_context_data.query_params
            headers_data = request_context_data.headers
            body_data = request_context_data.body
            # Give the test case a chance to modify the generated data.
            # Note: the generate_* methods in BaseAPITestCase receive the
            # previously prepared defaults, since they may need the original
            # endpoint definition for more complex logic.
            current_q_params = test_case_instance.generate_query_params(query_params_data)
            current_headers = test_case_instance.generate_headers(headers_data)
            current_body = test_case_instance.generate_request_body(body_data)
            # Path params are normally finalized by the orchestrator from the path
            # template; a test case may override them by providing its own
            # generate_path_params. Only call it when the test case actually
            # overrides the base implementation (checked via __func__ identity).
            current_path_params = test_case_instance.generate_path_params(path_params_data) if hasattr(test_case_instance, 'generate_path_params') and callable(getattr(test_case_instance, 'generate_path_params')) and getattr(test_case_instance, 'generate_path_params').__func__ != BaseAPITestCase.generate_path_params else path_params_data
            final_url_template = endpoint_spec_dict.get('path', '')
            # Log the path parameters that are about to be substituted.
            self.logger.debug(f"Path parameters to be substituted: {current_path_params}")
            final_url = self.base_url + final_url_template
            for p_name, p_val in current_path_params.items():
                placeholder = f"{{{p_name}}}"
                if placeholder in final_url_template: # only substitute placeholders present in the original template
                    final_url = final_url.replace(placeholder, str(p_val))
            # Log the URL after substitution (before the test case's modify hook runs).
            self.logger.debug(f"URL after path parameter substitution (before TC modify_request_url hook): {final_url}")
            # ---- invoke the test case's URL-modification hook ----
            effective_url = final_url # default to the originally built URL
            if hasattr(test_case_instance, 'modify_request_url') and callable(getattr(test_case_instance, 'modify_request_url')):
                try:
                    modified_url_by_tc = test_case_instance.modify_request_url(final_url)
                    if modified_url_by_tc != final_url:
                        test_case_instance.logger.info(f"Test case '{test_case_instance.id}' modified URL from '{final_url}' to '{modified_url_by_tc}'")
                        effective_url = modified_url_by_tc # use the URL as modified by the test case
                    else:
                        test_case_instance.logger.debug(f"Test case '{test_case_instance.id}' did not modify the URL via modify_request_url hook.")
                except Exception as e_url_mod:
                    test_case_instance.logger.error(f"Error in test case '{test_case_instance.id}' during modify_request_url: {e_url_mod}. Using original URL '{final_url}'.", exc_info=True)
                    # effective_url stays as final_url
            else:
                test_case_instance.logger.debug(f"Test case '{test_case_instance.id}' does not have a callable modify_request_url method. Using original URL.")
            # ---- end URL-modification hook ----
            api_request_context = APIRequestContext(
                method=method, # method obtained from _prepare_initial_request_data
                url=effective_url, # <--- use effective_url
                path_params=current_path_params,
                query_params=current_q_params,
                headers=current_headers,
                body=current_body,
                endpoint_spec=endpoint_spec_dict
            )
            validation_results.extend(test_case_instance.validate_request_url(api_request_context.url, api_request_context))
            validation_results.extend(test_case_instance.validate_request_headers(api_request_context.headers, api_request_context))
            validation_results.extend(test_case_instance.validate_request_body(api_request_context.body, api_request_context))
            critical_pre_validation_failure = False
            failure_messages = []
            for vp in validation_results:
                if not vp.passed and test_case_instance.severity in [TestSeverity.CRITICAL, TestSeverity.HIGH]: # Check severity of the Test Case for pre-validation
                    critical_pre_validation_failure = True
                    failure_messages.append(vp.message)
            if critical_pre_validation_failure:
                # Pre-validation failed for a CRITICAL/HIGH test case: abort
                # without sending the request and report FAILED.
                self.logger.warning(f"测试用例 '{test_case_instance.id}' 因请求预校验失败而中止 (TC严重级别: {test_case_instance.severity.value})。失败信息: {'; '.join(failure_messages)}")
                tc_duration = time.monotonic() - start_time
                return ExecutedTestCaseResult(
                    test_case_id=test_case_instance.id,
                    test_case_name=test_case_instance.name,
                    test_case_severity=test_case_instance.severity,
                    status=ExecutedTestCaseResult.Status.FAILED,
                    validation_points=[vp.to_dict() for vp in validation_results],
                    message=f"请求预校验失败: {'; '.join(failure_messages)}",
                    duration=tc_duration
                )
            api_request_obj = APIRequest(
                method=api_request_context.method,
                url=api_request_context.url,
                params=api_request_context.query_params,
                headers=api_request_context.headers,
                json_data=api_request_context.body
            )
            response_call_start_time = time.time()
            # api_response_obj = self.api_caller.call_api(api_request_obj)
            api_response, api_call_detail = self.api_caller.call_api(api_request_obj)
            self.global_api_call_details.append(api_call_detail) # record the call detail for reporting
            response_call_elapsed_time = time.time() - response_call_start_time
            actual_text_content: Optional[str] = None
            # Use the unpacked api_response:
            if hasattr(api_response, 'text_content') and api_response.text_content is not None:
                actual_text_content = api_response.text_content
            elif api_response.json_content is not None: # <--- use api_response
                if isinstance(api_response.json_content, str):
                    actual_text_content = api_response.json_content
                else:
                    try:
                        actual_text_content = json.dumps(api_response.json_content, ensure_ascii=False)
                    except TypeError:
                        actual_text_content = str(api_response.json_content)
            api_response_context = APIResponseContext(
                status_code=api_response.status_code, # <--- use api_response
                headers=api_response.headers, # <--- use api_response
                json_content=api_response.json_content, # <--- use api_response
                text_content=actual_text_content,
                elapsed_time=response_call_elapsed_time,
                original_response= getattr(api_response, 'raw_response', None), # <--- use api_response
                request_context=api_request_context
            )
            validation_results.extend(test_case_instance.validate_response(api_response_context, api_request_context))
            validation_results.extend(test_case_instance.check_performance(api_response_context, api_request_context))
            final_status = ExecutedTestCaseResult.Status.PASSED
            if any(not vp.passed for vp in validation_results):
                final_status = ExecutedTestCaseResult.Status.FAILED
            tc_duration = time.monotonic() - start_time
            return ExecutedTestCaseResult(
                test_case_id=test_case_instance.id,
                test_case_name=test_case_instance.name,
                test_case_severity=test_case_instance.severity,
                status=final_status,
                validation_points=[vp.to_dict() for vp in validation_results],
                duration=tc_duration
            )
        except Exception as e:
            self.logger.error(f"执行测试用例 '{test_case_class.id if hasattr(test_case_class, 'id') else test_case_class.__name__}' (在实例化阶段或之前) 时发生严重错误: {e}", exc_info=True)
            # If instantiation failed, test_case_instance is still None.
            tc_id_for_log = test_case_instance.id if test_case_instance else (test_case_class.id if hasattr(test_case_class, 'id') else "unknown_tc_id_instantiation_error")
            tc_name_for_log = test_case_instance.name if test_case_instance else (test_case_class.name if hasattr(test_case_class, 'name') else test_case_class.__name__)
            # Instantiation failures default to CRITICAL severity.
            tc_severity_for_log = test_case_instance.severity if test_case_instance else TestSeverity.CRITICAL
            tc_duration = time.monotonic() - start_time
            # validation_results may be empty here, or may contain entries from
            # earlier steps (if the error occurred after instantiation).
            return ExecutedTestCaseResult(
                test_case_id=tc_id_for_log,
                test_case_name=tc_name_for_log,
                test_case_severity=tc_severity_for_log,
                status=ExecutedTestCaseResult.Status.ERROR,
                validation_points=[vp.to_dict() for vp in validation_results], # Ensure validation_results is defined (it is, at the start of the function)
                message=f"测试用例执行时发生内部错误 (可能在实例化期间): {str(e)}",
                duration=tc_duration
            )
    def _prepare_initial_request_data(
        self,
        endpoint_spec: Dict[str, Any], # already converted to a dict
        test_case_instance: Optional[BaseAPITestCase] = None # passed in so its LLM configuration can be consulted
    ) -> APIRequestContext: # returns an APIRequestContext object
        """
        Prepare the initial request data from the API endpoint spec, including the
        URL template, path parameters, query parameters, headers and request body.

        This data serves as the input to the test case's generate_* methods.
        Values come, in order of preference, from: the LLM service (when enabled
        for the parameter kind), explicit examples/defaults in the spec, and
        finally type-based placeholder values.
        """
        method = endpoint_spec.get("method", "GET").upper()
        path_template = endpoint_spec.get("path", "/")
        operation_id = endpoint_spec.get("operationId", path_template) # use path as the fallback for operationId
        initial_path_params: Dict[str, Any] = {}
        initial_query_params: Dict[str, Any] = {}
        initial_headers: Dict[str, str] = {}
        initial_body: Optional[Any] = None
        parameters = endpoint_spec.get('parameters', [])
        # 1. Handle path parameters
        path_param_specs = [p for p in parameters if p.get('in') == 'path']
        if path_param_specs:
            should_use_llm = self._should_use_llm_for_param_type("path_params", test_case_instance)
            if should_use_llm and self.llm_service:
                self.logger.info(f"Attempting LLM generation for path parameters in '{operation_id}'")
                path_schema, path_model_name = self._build_object_schema_for_params(path_param_specs, f"{operation_id}_PathParams")
                if path_schema:
                    llm_path_params = self.llm_service.generate_data_from_schema(
                        path_schema,
                        prompt_instruction=None,
                        max_tokens=256,
                        temperature=0.1
                    )
                    if llm_path_params:
                        initial_path_params = llm_path_params
                    else:
                        self.logger.warning(f"LLM failed to generate path params for '{operation_id}', fallback to default.")
                else:
                    self.logger.warning(f"Failed to build schema for path params in '{operation_id}', fallback to default.")
            if not initial_path_params: # fallback: examples/defaults from the spec, then type-based placeholders
                for param_spec in path_param_specs:
                    name = param_spec.get('name')
                    if not name: continue
                    if 'example' in param_spec:
                        initial_path_params[name] = param_spec['example']
                    elif param_spec.get('schema') and 'example' in param_spec['schema']:
                        initial_path_params[name] = param_spec['schema']['example'] # OpenAPI 3.0 `parameter.schema.example`
                    elif 'default' in param_spec.get('schema', {}):
                        initial_path_params[name] = param_spec['schema']['default']
                    elif 'default' in param_spec: # OpenAPI 2.0 `parameter.default`
                        initial_path_params[name] = param_spec['default']
                    else:
                        schema = param_spec.get('schema', {})
                        param_type = schema.get('type', 'string')
                        if param_type == 'integer': initial_path_params[name] = 123
                        elif param_type == 'number': initial_path_params[name] = 1.23
                        elif param_type == 'boolean': initial_path_params[name] = True
                        elif param_type == 'string' and schema.get('format') == 'uuid': initial_path_params[name] = str(UUID(int=0)) # Example UUID
                        elif param_type == 'string' and schema.get('format') == 'date': initial_path_params[name] = dt.date.today().isoformat()
                        elif param_type == 'string' and schema.get('format') == 'date-time': initial_path_params[name] = dt.datetime.now().isoformat()
                        else: initial_path_params[name] = f"param_{name}"
        self.logger.debug(f"Initial path param for '{operation_id}': {initial_path_params}")
        # 2. Handle query parameters
        query_param_specs = [p for p in parameters if p.get('in') == 'query']
        if query_param_specs:
            should_use_llm = self._should_use_llm_for_param_type("query_params", test_case_instance)
            if should_use_llm and self.llm_service:
                self.logger.info(f"Attempting LLM generation for query parameters in '{operation_id}'")
                query_schema, query_model_name = self._build_object_schema_for_params(query_param_specs, f"{operation_id}_QueryParams")
                if query_schema:
                    llm_query_params = self.llm_service.generate_data_from_schema(
                        query_schema,
                        prompt_instruction=None,
                        max_tokens=512,
                        temperature=0.1
                    )
                    if llm_query_params:
                        initial_query_params = llm_query_params
                    else:
                        self.logger.warning(f"LLM failed to generate query params for '{operation_id}', fallback to default.")
            if not initial_query_params: # fallback: examples/defaults from the spec, then a placeholder string
                for param_spec in query_param_specs:
                    name = param_spec.get('name')
                    if not name: continue
                    if 'example' in param_spec:
                        initial_query_params[name] = param_spec['example']
                    elif param_spec.get('schema') and 'example' in param_spec['schema']:
                        initial_query_params[name] = param_spec['schema']['example']
                    elif 'default' in param_spec.get('schema', {}):
                        initial_query_params[name] = param_spec['schema']['default']
                    elif 'default' in param_spec:
                        initial_query_params[name] = param_spec['default']
                    else:
                        initial_query_params[name] = f"query_val_{name}"
        self.logger.debug(f"Initial query param for '{operation_id}': {initial_query_params}")
        # 3. Handle header parameters (spec-defined ones plus the standard Content-Type/Accept)
        header_param_specs = [p for p in parameters if p.get('in') == 'header']
        custom_header_param_specs = [p for p in header_param_specs if p.get('name', '').lower() not in ['content-type', 'accept', 'authorization']]
        if custom_header_param_specs:
            should_use_llm = self._should_use_llm_for_param_type("headers", test_case_instance)
            if should_use_llm and self.llm_service:
                self.logger.info(f"Attempting LLM generation for header parameters in '{operation_id}'")
                header_schema, header_model_name = self._build_object_schema_for_params(custom_header_param_specs, f"{operation_id}_HeaderParams")
                if header_schema:
                    llm_header_params = self.llm_service.generate_data_from_schema(
                        header_schema,
                        prompt_instruction=None,
                        max_tokens=256,
                        temperature=0.1
                    )
                    if llm_header_params:
                        # Header values must be strings.
                        for k, v in llm_header_params.items():
                            initial_headers[k] = str(v)
                    else:
                        self.logger.warning(f"LLM failed to generate header params for '{operation_id}', fallback to default.")
            if not any(k for k in initial_headers if k.lower() not in ['content-type', 'accept', 'authorization']): # fallback
                for param_spec in custom_header_param_specs:
                    name = param_spec.get('name')
                    if not name: continue
                    if 'example' in param_spec:
                        initial_headers[name] = str(param_spec['example'])
                    elif param_spec.get('schema') and 'example' in param_spec['schema']:
                        initial_headers[name] = str(param_spec['schema']['example'])
                    elif 'default' in param_spec.get('schema', {}):
                        initial_headers[name] = str(param_spec['schema']['default'])
                    elif 'default' in param_spec:
                        initial_headers[name] = str(param_spec['default'])
                    else:
                        initial_headers[name] = f"header_val_{name}"
        self.logger.debug(f"Initial custom header param for '{operation_id}': {initial_headers}")
        # 3.1 Set Content-Type
        # Prefer requestBody.content (OpenAPI 3.x)
        request_body_spec_candidate = endpoint_spec.get('requestBody')
        request_body_spec = request_body_spec_candidate if isinstance(request_body_spec_candidate, dict) else {}
        if 'content' in request_body_spec:
            content_types = list(request_body_spec['content'].keys())
            if content_types:
                # Prefer application/json when present
                initial_headers['Content-Type'] = next((ct for ct in content_types if 'json' in ct.lower()), content_types[0])
        elif 'consumes' in endpoint_spec: # then 'consumes' (OpenAPI 2.0)
            consumes = endpoint_spec['consumes']
            if consumes:
                initial_headers['Content-Type'] = next((c for c in consumes if 'json' in c.lower()), consumes[0])
        elif method in ['POST', 'PUT', 'PATCH'] and not initial_headers.get('Content-Type'):
            initial_headers['Content-Type'] = 'application/json' # default for these methods
        self.logger.debug(f"Initial Content-Type for '{operation_id}': {initial_headers.get('Content-Type')}")
        # 3.2 Set Accept
        # Prefer responses.<code>.content (OpenAPI 3.x)
        responses_spec = endpoint_spec.get('responses', {})
        accept_header_set = False
        for code, response_def in responses_spec.items():
            if 'content' in response_def:
                accept_types = list(response_def['content'].keys())
                if accept_types:
                    initial_headers['Accept'] = next((at for at in accept_types if 'json' in at.lower() or '*/*' in at), accept_types[0])
                    accept_header_set = True
                    break
        if not accept_header_set and 'produces' in endpoint_spec: # then 'produces' (OpenAPI 2.0)
            produces = endpoint_spec['produces']
            if produces:
                initial_headers['Accept'] = next((p for p in produces if 'json' in p.lower() or '*/*' in p), produces[0])
                accept_header_set = True
        if not accept_header_set and not initial_headers.get('Accept'):
            initial_headers['Accept'] = 'application/json, */*' # a more general default
        self.logger.debug(f"Initial Accept header for '{operation_id}': {initial_headers.get('Accept')}")
        # 4. Handle the request body
        request_body_schema: Optional[Dict[str, Any]] = None
        # Determine where the body schema comes from, preferring the OpenAPI 3.x requestBody
        content_type_for_body_schema = initial_headers.get('Content-Type', 'application/json').split(';')[0].strip()
        if 'content' in request_body_spec and content_type_for_body_schema in request_body_spec['content']:
            request_body_schema = request_body_spec['content'][content_type_for_body_schema].get('schema')
        elif 'parameters' in endpoint_spec: # OpenAPI 2.0 (Swagger) body parameter
            body_param = next((p for p in parameters if p.get('in') == 'body'), None)
            if body_param and 'schema' in body_param:
                request_body_schema = body_param['schema']
        if request_body_schema:
            should_use_llm_for_body = self._should_use_llm_for_param_type("body", test_case_instance)
            if should_use_llm_for_body and self.llm_service:
                self.logger.info(f"Attempting LLM generation for request body of '{operation_id}' with schema...")
                initial_body = self.llm_service.generate_data_from_schema(
                    request_body_schema,
                    prompt_instruction=None, # replace with a custom instruction if needed
                    max_tokens=1024,
                    temperature=0.1
                )
                if initial_body is None:
                    self.logger.warning(f"LLM failed to generate request body for '{operation_id}'. Falling back to default schema generator.")
                    initial_body = self._generate_data_from_schema(request_body_schema, context_name=f"{operation_id}_body", operation_id=operation_id)
            else:
                initial_body = self._generate_data_from_schema(request_body_schema, context_name=f"{operation_id}_body", operation_id=operation_id)
            self.logger.debug(f"Initial request body generated for '{operation_id}' (type: {type(initial_body)})")
        else:
            self.logger.debug(f"No request body schema found or applicable for '{operation_id}' with Content-Type '{content_type_for_body_schema}'. Initial body is None.")
        # Build and return the APIRequestContext
        return APIRequestContext(
            method=method,
            url=path_template, # pass the path template, e.g. /items/{itemId}
            path_params=initial_path_params,
            query_params=initial_query_params,
            headers=initial_headers,
            body=initial_body,
            endpoint_spec=endpoint_spec # pass the original endpoint_spec dict
        )
def _build_object_schema_for_params(self, params_spec_list: List[Dict[str, Any]], model_name_base: str) -> Tuple[Optional[Dict[str, Any]], str]:
"""
将参数列表 (如路径参数、查询参数列表) 转换为一个单一的 "type: object" JSON schema
以便用于创建 Pydantic 模型。
会尝试适配参数定义中缺少嵌套 'schema' 字段但有顶层 'type' 的情况。
"""
if not params_spec_list:
return None, model_name_base
properties = {}
required_params = []
parameter_names = []
for param_spec in params_spec_list:
param_name = param_spec.get("name")
if not param_name:
self.logger.warning(f"参数定义缺少 'name' 字段: {param_spec}。已跳过。")
continue
parameter_names.append(param_name)
param_schema = param_spec.get("schema")
# ---- 适配开始 ----
if not param_schema and param_spec.get("type"):
self.logger.debug(f"参数 '{param_name}' 缺少嵌套 'schema' 字段,尝试从顶层 'type' 构建临时schema。 Param spec: {param_spec}")
temp_schema = {"type": param_spec.get("type")}
# 从 param_spec 顶层提取其他相关字段到 temp_schema
for key in ["format", "default", "example", "description", "enum",
"minimum", "maximum", "minLength", "maxLength", "pattern",
"items"]: # items 用于处理顶层定义的array
if key in param_spec:
temp_schema[key] = param_spec[key]
param_schema = temp_schema
# ---- 适配结束 ----
if not param_schema: # 如果适配后仍然没有schema
self.logger.warning(f"参数 '{param_name}' 缺少 'schema' 定义且无法从顶层构建: {param_spec}。已跳过。")
continue
# 处理 $ref (简单情况假设ref在components.schemas)
# 更复杂的 $ref 解析可能需要访问完整的OpenAPI文档
if isinstance(param_schema, dict) and "$ref" in param_schema: # 确保 param_schema 是字典再检查 $ref
ref_path = param_schema["$ref"]
# 这是一个非常简化的$ref处理实际可能需要解析整个文档
self.logger.warning(f"参数 '{param_name}' 的 schema 包含 $ref '{ref_path}'当前不支持自动解析。请确保schema是内联的。")
# 可以尝试提供一个非常基础的schema或者跳过这个参数或者让_generate_data_from_schema处理
properties[param_name] = {"type": "string", "description": f"Reference to {ref_path}"}
elif isinstance(param_schema, dict): # 确保 param_schema 是字典
properties[param_name] = param_schema
else:
self.logger.warning(f"参数 '{param_name}' 的 schema 不是一个有效的字典: {param_schema}。已跳过。")
continue
if param_spec.get("required", False):
required_params.append(param_name)
if not properties: # 如果所有参数都无效
return None, model_name_base
model_name = f"{model_name_base}_{'_'.join(sorted(parameter_names))}" # 使模型名更具唯一性
object_schema = {
"type": "object",
"properties": properties,
}
if required_params:
object_schema["required"] = required_params
self.logger.debug(f"[{model_name_base}] 为参数集 {parameter_names} 构建的最终 Object Schema: {json.dumps(object_schema, indent=2)}, 模型名: {model_name}")
return object_schema, model_name
def _generate_params_from_list(self, params_spec_list: List[Dict[str, Any]], operation_id: str, param_type: str) -> Dict[str, Any]:
"""
遍历参数定义列表,使用 _generate_data_from_schema 为每个参数生成数据。
会尝试适配参数定义中缺少嵌套 'schema' 字段但有顶层 'type' 的情况。
"""
generated_params: Dict[str, Any] = {}
if not params_spec_list:
self.logger.info(f"[{operation_id}] 没有定义 {param_type} 参数。")
return generated_params
self.logger.info(f"[{operation_id}] 使用常规方法生成 {param_type} 参数。")
for param_spec in params_spec_list:
param_name = param_spec.get("name")
param_schema = param_spec.get("schema")
# ---- 适配开始 ----
if not param_schema and param_spec.get("type"):
self.logger.debug(f"参数 '{param_name}' ('{param_type}' 类型) 缺少嵌套 'schema' 字段,尝试从顶层 'type' 构建临时schema用于常规生成。 Param spec: {param_spec}")
temp_schema = {"type": param_spec.get("type")}
# 从 param_spec 顶层提取其他相关字段到 temp_schema
for key in ["format", "default", "example", "description", "enum",
"minimum", "maximum", "minLength", "maxLength", "pattern",
"items"]: # items 用于处理顶层定义的array
if key in param_spec:
temp_schema[key] = param_spec[key]
param_schema = temp_schema
# ---- 适配结束 ----
if param_name and param_schema and isinstance(param_schema, dict): # 确保param_schema是字典
generated_value = self._generate_data_from_schema(
param_schema,
context_name=f"{param_type} parameter '{param_name}'",
operation_id=operation_id
)
if generated_value is not None:
generated_params[param_name] = generated_value
elif param_spec.get("required"):
self.logger.warning(f"[{operation_id}] 未能为必需的 {param_type} 参数 '{param_name}' 生成数据 (schema: {param_schema}),且其 schema 中可能没有有效的默认值或示例。")
else:
self.logger.warning(f"[{operation_id}] 跳过无效的 {param_type} 参数定义 (名称: {param_name}, schema: {param_schema}): {param_spec}")
self.logger.info(f"[{operation_id}] 常规方法生成的 {param_type} 参数: {generated_params}")
return generated_params
def run_test_for_endpoint(self, endpoint: Union[YAPIEndpoint, SwaggerEndpoint, DMSEndpoint],
                          global_api_spec: Union[ParsedYAPISpec, ParsedSwaggerSpec, ParsedDMSSpec]
                          ) -> TestResult:
    """Run every applicable registered test case against a single endpoint.

    Applicable test cases are fetched from ``self.test_case_registry``,
    sorted by their ``execution_order`` and executed one by one. When a test
    case flagged ``is_critical_setup_test`` fails or errors, all remaining
    test cases for this endpoint are recorded as SKIPPED. DMS endpoints with
    ``test_mode == 'scenario_only'`` are skipped entirely.

    Returns:
        TestResult: the finalized, aggregated result for this endpoint.
    """
    # Endpoints flagged for scenario-only testing are not tested standalone.
    if isinstance(endpoint, DMSEndpoint) and hasattr(endpoint, 'test_mode') and endpoint.test_mode == 'scenario_only':
        self.logger.info(f"跳过对仅场景测试端点的独立测试: {endpoint.method} {endpoint.path}")
        result = TestResult(
            endpoint_id=f"{endpoint.method}_{endpoint.path}",
            endpoint_name=endpoint.title or f"{endpoint.method} {endpoint.path}"
        )
        result.overall_status = TestResult.Status.SKIPPED
        result.message = "此端点标记为仅在场景中测试 (test_mode='scenario_only')"
        return result
    endpoint_id = f"{getattr(endpoint, 'method', 'GET').upper()} {getattr(endpoint, 'path', '/')}"
    endpoint_name = getattr(endpoint, 'title', '') or getattr(endpoint, 'summary', '') or endpoint_id
    self.logger.info(f"开始为端点测试: {endpoint_id} ({endpoint_name})")
    endpoint_test_result = TestResult(
        endpoint_id=endpoint_id,
        endpoint_name=endpoint_name,
    )
    # Without a registry there is nothing to execute; mark the endpoint skipped.
    if not self.test_case_registry:
        self.logger.warning(f"TestCaseRegistry 未初始化,无法为端点 '{endpoint_id}' 执行自定义测试用例。")
        endpoint_test_result.overall_status = TestResult.Status.SKIPPED
        endpoint_test_result.error_message = "TestCaseRegistry 未初始化。"
        endpoint_test_result.finalize_endpoint_test(strictness_level=self.strictness_level)
        return endpoint_test_result
    applicable_test_case_classes_unordered = self.test_case_registry.get_applicable_test_cases(
        endpoint_method=endpoint.method.upper(),
        endpoint_path=endpoint.path
    )
    if not applicable_test_case_classes_unordered:
        self.logger.info(f"端点 '{endpoint_id}' 没有找到适用的自定义测试用例。")
        endpoint_test_result.finalize_endpoint_test(strictness_level=self.strictness_level)  # ensure finalization before returning
        return endpoint_test_result
    # Order test cases by their declared execution_order.
    applicable_test_case_classes = sorted(
        applicable_test_case_classes_unordered,
        key=lambda tc_class: tc_class.execution_order
    )
    self.logger.info(f"端点 '{endpoint_id}' 发现了 {len(applicable_test_case_classes)} 个适用的测试用例 (已排序): {[tc.id for tc in applicable_test_case_classes]}")
    critical_setup_test_failed = False
    critical_setup_failure_reason = ""
    for tc_class in applicable_test_case_classes:
        start_single_tc_time = time.monotonic()  # also times skipped test cases
        if critical_setup_test_failed:
            # A critical setup test already failed: record this one as skipped.
            self.logger.warning(f"由于关键的前置测试用例失败,跳过测试用例 '{tc_class.id}' for '{endpoint_id}'. 原因: {critical_setup_failure_reason}")
            skipped_tc_duration = time.monotonic() - start_single_tc_time
            executed_case_result = ExecutedTestCaseResult(
                test_case_id=tc_class.id,
                test_case_name=tc_class.name,
                test_case_severity=tc_class.severity,
                status=ExecutedTestCaseResult.Status.SKIPPED,
                validation_points=[],
                message=f"由于关键的前置测试失败而被跳过: {critical_setup_failure_reason}",
                duration=skipped_tc_duration
            )
        else:
            self.logger.debug(f"准备执行测试用例 '{tc_class.id}' for '{endpoint_id}'")
            executed_case_result = self._execute_single_test_case(
                test_case_class=tc_class,
                endpoint_spec=endpoint,
                global_api_spec=global_api_spec
            )
            # If this was a critical setup test and it failed/errored,
            # every remaining test case for this endpoint will be skipped.
            if hasattr(tc_class, 'is_critical_setup_test') and tc_class.is_critical_setup_test:
                if executed_case_result.status in [ExecutedTestCaseResult.Status.FAILED, ExecutedTestCaseResult.Status.ERROR]:
                    critical_setup_test_failed = True
                    critical_setup_failure_reason = f"关键测试 '{tc_class.id}' 失败 (状态: {executed_case_result.status.value})。消息: {executed_case_result.message}"
                    self.logger.error(f"关键的前置测试用例 '{tc_class.id}' for '{endpoint_id}' 失败。后续测试将被跳过。原因: {critical_setup_failure_reason}")
        endpoint_test_result.add_executed_test_case_result(executed_case_result)
        # Colorized per-status debug logging (ANSI escape sequences).
        if executed_case_result.status.value == ExecutedTestCaseResult.Status.FAILED.value:
            self.logger.debug(f"\033[91m ❌ 测试用例 '{tc_class.id}' 执行失败。\033[0m")
        elif executed_case_result.status.value == ExecutedTestCaseResult.Status.PASSED.value :
            self.logger.debug(f"\033[92m ✅ 测试用例 '{tc_class.id}' 执行成功。\033[0m")
        elif executed_case_result.status.value == ExecutedTestCaseResult.Status.SKIPPED.value:
            self.logger.debug(f"\033[93m ⏭️ 测试用例 '{tc_class.id}' 被跳过。\033[0m")  # yellow
        elif executed_case_result.status.value == ExecutedTestCaseResult.Status.ERROR.value:
            self.logger.debug(f"\033[91m 💥 测试用例 '{tc_class.id}' 执行时发生错误。\033[0m")  # red (same color as FAILED)
        self.logger.debug(f"测试用例 '{tc_class.id}' 执行完毕,状态: {executed_case_result.status.value}")
    endpoint_test_result.finalize_endpoint_test(strictness_level=self.strictness_level)
    self.logger.info(f"端点 '{endpoint_id}' 测试完成,最终状态: {endpoint_test_result.overall_status.value}")
    return endpoint_test_result
def run_tests_from_yapi(self, yapi_file_path: str,
                        categories: Optional[List[str]] = None,
                        custom_test_cases_dir: Optional[str] = None
                        ) -> Tuple[TestSummary, Optional[ParsedAPISpec]]:
    """Parse a YAPI export file and run the test suite defined by it.

    Args:
        yapi_file_path: Path to the YAPI export file.
        categories: Optional list of YAPI categories to restrict testing to.
        custom_test_cases_dir: Optional directory with extra test cases.

    Returns:
        ``(summary, parsed_spec)``; ``parsed_spec`` is ``None`` when parsing
        failed. ``finalize_summary``/console printing of a successful run is
        left to the caller (run_api_tests.py).
    """
    self.logger.info(f"准备从YAPI文件运行测试用例: {yapi_file_path}")
    # Start every run with a clean global API-call log.
    self.global_api_call_details = []
    parsed_yapi = self.parser.parse_yapi_spec(yapi_file_path)
    summary = TestSummary()
    if not parsed_yapi:
        self.logger.error(f"解析YAPI文件失败: {yapi_file_path}")
        # Finalize the summary even when parsing failed.
        summary.finalize_summary()
        return summary, None
    # Delegate the actual execution to the shared internal runner.
    self._execute_tests_from_parsed_spec(
        parsed_spec=parsed_yapi,
        summary=summary,
        categories=categories,
        custom_test_cases_dir=custom_test_cases_dir
    )
    return summary, parsed_yapi
def run_tests_from_swagger(self, swagger_file_path: str,
                           tags: Optional[List[str]] = None,
                           custom_test_cases_dir: Optional[str] = None
                           ) -> Tuple[TestSummary, Optional[ParsedAPISpec]]:
    """Parse a Swagger/OpenAPI file and run the test suite defined by it.

    Args:
        swagger_file_path: Path to the Swagger/OpenAPI document.
        tags: Optional list of tags to restrict testing to.
        custom_test_cases_dir: Optional directory with extra test cases.

    Returns:
        ``(summary, parsed_spec)``; ``parsed_spec`` is ``None`` when parsing
        failed. ``finalize_summary``/console printing of a successful run is
        left to the caller (run_api_tests.py).
    """
    self.logger.info(f"准备从Swagger文件运行测试用例: {swagger_file_path}")
    # Start every run with a clean global API-call log.
    self.global_api_call_details = []
    parsed_swagger = self.parser.parse_swagger_spec(swagger_file_path)
    summary = TestSummary()
    if not parsed_swagger:
        self.logger.error(f"解析Swagger文件失败: {swagger_file_path}")
        # Finalize the summary even when parsing failed.
        summary.finalize_summary()
        return summary, None
    # Delegate the actual execution to the shared internal runner.
    self._execute_tests_from_parsed_spec(
        parsed_spec=parsed_swagger,
        summary=summary,
        tags=tags,
        custom_test_cases_dir=custom_test_cases_dir
    )
    return summary, parsed_swagger
def _generate_data_from_schema(self, schema: Dict[str, Any],
context_name: Optional[str] = None,
operation_id: Optional[str] = None) -> Any:
"""
根据JSON Schema生成测试数据 (此方法基本保持不变,可能被测试用例或编排器内部使用)
增加了 context_name 和 operation_id 用于更详细的日志。
"""
log_prefix = f"[{operation_id}] " if operation_id else ""
context_log = f" (context: {context_name})" if context_name else ""
if not schema or not isinstance(schema, dict):
self.logger.debug(f"{log_prefix}_generate_data_from_schema: 提供的 schema 无效或为空{context_log}: {schema}")
return None
schema_type = schema.get('type')
if 'example' in schema:
self.logger.debug(f"{log_prefix}使用 schema 中的 'example' 值 for{context_log}: {schema['example']}")
return schema['example']
if 'default' in schema:
self.logger.debug(f"{log_prefix}使用 schema 中的 'default' 值 for{context_log}: {schema['default']}")
return schema['default']
# Handle both 'object' and 'Object' (case-insensitive)
if schema_type and schema_type.lower() == 'object':
result = {}
properties = schema.get('properties', {})
self.logger.debug(f"{log_prefix}生成 object 类型数据 for{context_log}. Properties: {list(properties.keys())}")
for prop_name, prop_schema in properties.items():
# 递归调用时传递上下文,但稍微修改一下 context_name
nested_context = f"{context_name}.{prop_name}" if context_name else prop_name
result[prop_name] = self._generate_data_from_schema(prop_schema, nested_context, operation_id)
return result if result else {}
# Handle both 'array' and 'Array' (case-insensitive)
elif schema_type and schema_type.lower() == 'array':
items_schema = schema.get('items', {})
min_items = schema.get('minItems', 1 if schema.get('default') is None and schema.get('example') is None else 0)
self.logger.debug(f"{log_prefix}生成 array 类型数据 for{context_log}. Items schema: {items_schema}, minItems: {min_items}")
if min_items == 0 and (schema.get('default') == [] or schema.get('example') == []):
return []
num_items_to_generate = max(1, min_items)
generated_array = []
for i in range(num_items_to_generate):
item_context = f"{context_name}[{i}]" if context_name else f"array_item[{i}]"
generated_array.append(self._generate_data_from_schema(items_schema, item_context, operation_id))
return generated_array
# Handle both 'string' and 'String' (case-insensitive)
elif schema_type and schema_type.lower() == 'string':
string_format = schema.get('format', '')
val = None
if 'enum' in schema and schema['enum']:
val = schema['enum'][0]
elif string_format == 'date': val = '2023-01-01'
elif string_format == 'date-time': val = datetime.datetime.now().isoformat()
elif string_format == 'email': val = 'test@example.com'
elif string_format == 'uuid': import uuid; val = str(uuid.uuid4())
else: val = 'example_string'
self.logger.debug(f"{log_prefix}生成 string 类型数据 ('{string_format}') for{context_log}: {val}")
return val
# Handle both 'number'/'Number' and 'integer'/'Integer' (case-insensitive)
elif (schema_type and schema_type.lower() == 'number') or (schema_type and schema_type.lower() == 'integer'):
val_to_return = schema.get('default', schema.get('example'))
if val_to_return is not None:
self.logger.debug(f"{log_prefix}使用 number/integer 的 default/example 值 for{context_log}: {val_to_return}")
return val_to_return
minimum = schema.get('minimum')
# maximum = schema.get('maximum') # Not used yet for generation, but could be
if minimum is not None:
val_to_return = minimum
else:
val_to_return = 0 if schema_type.lower() == 'integer' else 0.0
self.logger.debug(f"{log_prefix}生成 number/integer 类型数据 for{context_log}: {val_to_return}")
return val_to_return
# Handle both 'boolean' and 'Boolean' (case-insensitive)
elif schema_type and schema_type.lower() == 'boolean':
val = schema.get('default', schema.get('example', False))
self.logger.debug(f"{log_prefix}生成 boolean 类型数据 for{context_log}: {val}")
return val
elif schema_type == 'null':
self.logger.debug(f"{log_prefix}生成 null 类型数据 for{context_log}")
return None
self.logger.debug(f"{log_prefix}_generate_data_from_schema: 未知或不支持的 schema 类型 '{schema_type}' for{context_log}. Schema: {schema}")
return None
def _format_url_with_path_params(self, path_template: str, path_params: Dict[str, Any]) -> str:
"""
使用提供的路径参数格式化URL路径模板。
例如: path_template='/users/{userId}/items/{itemId}', path_params={'userId': 123, 'itemId': 'abc'}
会返回 '/users/123/items/abc'
同时处理 base_url.
"""
# 首先确保 path_template 不以 '/' 开头,如果 self.base_url 已经以 '/' 结尾
# 或者确保它们之间只有一个 '/'
formatted_path = path_template
for key, value in path_params.items():
placeholder = f"{{{key}}}"
if placeholder in formatted_path:
formatted_path = formatted_path.replace(placeholder, str(value))
else:
self.logger.warning(f"路径参数 '{key}' 在路径模板 '{path_template}' 中未找到占位符。")
# 拼接 base_url 和格式化后的路径
# 确保 base_url 和 path 之间只有一个斜杠
if self.base_url.endswith('/') and formatted_path.startswith('/'):
url = self.base_url + formatted_path[1:]
elif not self.base_url.endswith('/') and not formatted_path.startswith('/'):
if formatted_path: # 避免在 base_url 后添加不必要的 '/' (如果 formatted_path 为空)
url = self.base_url + '/' + formatted_path
else:
url = self.base_url
else:
url = self.base_url + formatted_path
return url
def _resolve_json_schema_references(self, schema_to_resolve: Any, full_api_spec: Dict[str, Any], max_depth=10, current_depth=0) -> Any:
"""
递归解析JSON Schema中的$ref引用。
Args:
schema_to_resolve: 当前需要解析的schema部分 (可以是字典、列表或基本类型)。
full_api_spec: 完整的API规范字典用于查找$ref路径。
max_depth: 最大递归深度,防止无限循环。
current_depth: 当前递归深度。
Returns:
解析了$ref的schema部分。
"""
if current_depth > max_depth:
self.logger.warning(f"达到最大$ref解析深度 ({max_depth}),可能存在循环引用。停止进一步解析。")
return schema_to_resolve
if isinstance(schema_to_resolve, dict):
if "$ref" in schema_to_resolve:
ref_path = schema_to_resolve["$ref"]
if not isinstance(ref_path, str) or not ref_path.startswith("#/"):
self.logger.warning(f"不支持的$ref格式或外部引用: {ref_path}。仅支持本地引用 (e.g., #/components/schemas/MyModel)。")
return schema_to_resolve # 或者根据需要返回错误/None
path_parts = ref_path[2:].split('/') # Remove '#/' and split
resolved_component = full_api_spec
try:
for part in path_parts:
if isinstance(resolved_component, list): # Handle paths like #/components/parameters/0
part = int(part)
resolved_component = resolved_component[part]
# 递归解析引用过来的组件,以处理嵌套的$ref
# 同时传递原始$ref携带的其他属性如description, nullable等可以覆盖引用的内容
# See: https://json-schema.org/understanding-json-schema/structuring.html#merging
# For simplicity here, we prioritize the resolved component, but a more robust solution
# would merge properties from the $ref object itself with the resolved one.
# Create a copy of the resolved component to avoid modifying the original spec
# and to allow merging of sibling keywords if any.
component_copy = copy.deepcopy(resolved_component)
# Merge sibling keywords from the $ref object into the resolved component.
# Keywords in the $ref object override those in the referenced schema.
merged_schema = component_copy
if isinstance(component_copy, dict): # Ensure it's a dict before trying to update
for key, value in schema_to_resolve.items():
if key != "$ref":
merged_schema[key] = value # Override or add
self.logger.debug(f"成功解析并合并 $ref: '{ref_path}'。正在递归解析其内容。")
return self._resolve_json_schema_references(merged_schema, full_api_spec, max_depth, current_depth + 1)
except (KeyError, IndexError, TypeError, ValueError) as e:
self.logger.error(f"解析$ref '{ref_path}' 失败: {e}.路径部分: {path_parts}. 当前组件类型: {type(resolved_component)}", exc_info=True)
return schema_to_resolve # 返回原始的$ref对象或错误指示
# 如果不是$ref则递归处理字典中的每个值
# 使用copy避免在迭代时修改字典
resolved_dict = {}
for key, value in schema_to_resolve.items():
resolved_dict[key] = self._resolve_json_schema_references(value, full_api_spec, max_depth, current_depth + 1)
return resolved_dict
elif isinstance(schema_to_resolve, list):
# 递归处理列表中的每个元素
return [self._resolve_json_schema_references(item, full_api_spec, max_depth, current_depth + 1) for item in schema_to_resolve]
else:
# 基本类型 (string, number, boolean, null) 不需要解析
return schema_to_resolve
def _util_find_removable_field_path_recursive(self, current_schema: Dict[str, Any], current_path: List[str], full_api_spec_for_refs: Dict[str, Any]) -> Optional[List[Union[str, int]]]:
"""
(框架辅助方法) 递归查找第一个可移除的必填字段的路径。
此方法现在需要 full_api_spec_for_refs 以便在需要时解析 $ref。
"""
# 首先解析当前 schema以防它是 $ref
resolved_schema = self._resolve_json_schema_references(current_schema, full_api_spec_for_refs)
if not isinstance(resolved_schema, dict) or resolved_schema.get("type") != "object":
return None
required_fields_at_current_level = resolved_schema.get("required", [])
properties = resolved_schema.get("properties", {})
self.logger.debug(f"[Util] 递归查找路径: {current_path}, 当前层级必填字段: {required_fields_at_current_level}, 属性: {list(properties.keys())}")
# 策略1: 查找当前层级直接声明的必填字段
if required_fields_at_current_level and properties:
for field_name in required_fields_at_current_level:
if field_name in properties:
self.logger.info(f"[Util] 策略1: 在路径 {'.'.join(map(str,current_path)) if current_path else 'root'} 找到可直接移除的必填字段: '{field_name}'")
return current_path + [field_name]
# 策略2: 查找数组属性看其内部item是否有必填字段
if properties:
for prop_name, prop_schema_orig in properties.items():
prop_schema = self._resolve_json_schema_references(prop_schema_orig, full_api_spec_for_refs)
if isinstance(prop_schema, dict) and prop_schema.get("type") == "array":
items_schema_orig = prop_schema.get("items")
if isinstance(items_schema_orig, dict):
items_schema = self._resolve_json_schema_references(items_schema_orig, full_api_spec_for_refs)
if isinstance(items_schema, dict) and items_schema.get("type") == "object":
item_required_fields = items_schema.get("required", [])
item_properties = items_schema.get("properties", {})
if item_required_fields and item_properties:
first_required_field_in_item = next((rf for rf in item_required_fields if rf in item_properties), None)
if first_required_field_in_item:
self.logger.info(f"[Util] 策略2: 在数组属性 '{prop_name}' (路径 {'.'.join(map(str,current_path)) if current_path else 'root'}) 的元素内找到必填字段: '{first_required_field_in_item}'. 路径: {current_path + [prop_name, 0, first_required_field_in_item]}")
return current_path + [prop_name, 0, first_required_field_in_item]
# 策略3: 递归到子对象中查找(可选,但对于通用工具可能有用)
# 注意:这可能会找到非顶层必填对象内部的必填字段。
# if properties:
# for prop_name, prop_schema_orig_for_recurse in properties.items():
# prop_schema_for_recurse = self._resolve_json_schema_references(prop_schema_orig_for_recurse, full_api_spec_for_refs)
# if isinstance(prop_schema_for_recurse, dict) and prop_schema_for_recurse.get("type") == "object":
# # Avoid re-checking fields already covered by strategy 1 if they were required at this level
# # if prop_name not in required_fields_at_current_level:
# self.logger.debug(f"[Util] 策略3: 尝试递归进入对象属性 '{prop_name}' (路径 {current_path})")
# found_path_deeper = self._util_find_removable_field_path_recursive(prop_schema_for_recurse, current_path + [prop_name], full_api_spec_for_refs)
# if found_path_deeper:
# return found_path_deeper
self.logger.debug(f"[Util] 在路径 {'.'.join(map(str,current_path)) if current_path else 'root'} 未通过任何策略找到可移除的必填字段。")
return None
def _util_remove_value_at_path(self, data_container: Any, path: List[Union[str, int]]) -> Tuple[Any, Any, bool]:
    """(Framework helper) Remove the value at *path* from a nested dict/list.

    *path* is a sequence of dict keys (str) and list indices (int). The
    container is deep-copied before modification, so the caller's data is
    never mutated; on structural mismatches the ORIGINAL container is
    returned unchanged. While traversing intermediate segments, missing
    dicts/lists are created as needed (supports starting from ``None``).

    Returns:
        Tuple of (modified container, removed value, success flag).
    """
    if not path:
        self.logger.error("[Util] _util_remove_value_at_path: 路径不能为空。")
        return data_container, None, False
    # Deep-copy to avoid mutating the caller's data (unless the caller expects it).
    # If the container is None and the path is non-empty, bootstrap a minimal structure.
    if data_container is None:
        if isinstance(path[0], str):  # path starts with a field name -> expect a dict
            container_copy = {}
        elif isinstance(path[0], int):  # path starts with an index -> expect a list
            container_copy = []
        else:
            self.logger.error(f"[Util] _util_remove_value_at_path: 路径的第一个元素 '{path[0]}' 类型未知。")
            return data_container, None, False
    else:
        container_copy = copy.deepcopy(data_container)
    current_level = container_copy
    original_value = None
    try:
        for i, key_or_index in enumerate(path):
            is_last_element = (i == len(path) - 1)
            if is_last_element:
                if isinstance(key_or_index, str):  # final segment: remove a dict key (field name)
                    if isinstance(current_level, dict) and key_or_index in current_level:
                        original_value = current_level.pop(key_or_index)
                        self.logger.info(f"[Util] 从路径 '{'.'.join(map(str,path))}' 成功移除字段 '{key_or_index}' (原值: '{original_value}')。")
                        return container_copy, original_value, True
                    elif isinstance(current_level, dict):
                        self.logger.warning(f"[Util] 路径的最后一部分 '{key_or_index}' (string key) 在对象中未找到。路径: {'.'.join(map(str,path))}")
                        return container_copy, None, False  # field absent, but structure matched
                    else:
                        self.logger.error(f"[Util] 路径的最后一部分 '{key_or_index}' (string key) 期望父级是字典,但找到 {type(current_level)}。路径: {'.'.join(map(str,path))}")
                        return data_container, None, False  # structure mismatch: return original data
                else:  # final segment is an index -> remove an item from a list
                    if isinstance(current_level, list) and isinstance(key_or_index, int) and 0 <= key_or_index < len(current_level):
                        original_value = current_level.pop(key_or_index)
                        self.logger.info(f"[Util] 从路径 '{'.'.join(map(str,path))}' 成功移除索引 '{key_or_index}' 的元素 (原值: '{original_value}')。")
                        return container_copy, original_value, True
                    elif isinstance(current_level, list):
                        self.logger.warning(f"[Util] 路径的最后一部分索引 '{key_or_index}' 超出列表范围或类型不符。列表长度: {len(current_level)}. 路径: {'.'.join(map(str,path))}")
                        return container_copy, None, False  # invalid index, but structure matched
                    else:
                        self.logger.error(f"[Util] 路径的最后一部分 '{key_or_index}' 期望父级是列表,但找到 {type(current_level)}。路径: {'.'.join(map(str,path))}")
                        return data_container, None, False  # structure mismatch
            else:  # not the last segment: traverse, building the structure if needed
                next_key_or_index = path[i+1]
                if isinstance(key_or_index, str):  # current segment is a dictionary key
                    if not isinstance(current_level, dict):
                        self.logger.debug(f"[Util] 路径期望字典,但在 '{key_or_index}' (父级)处找到 {type(current_level)}. 将创建空字典。")
                        # This should only happen when current_level came from a None
                        # container that we are building up. If current_level is not a
                        # dict and not the root being built, it's an error.
                        # For robust path creation from None:
                        if current_level is container_copy and not container_copy :  # building from scratch
                            current_level = {}  # must be reflected in container_copy when this is the root
                            if i == 0: container_copy = current_level
                            else:  # complex case: no way to link back to the parent if not root
                                self.logger.error(f"[Util] 无法在非根级别从非字典创建路径。")
                                return data_container, None, False
                        else:  # path expects a dict, but found something else not at root
                            self.logger.error(f"[Util] 路径期望字典,但在 '{key_or_index}' 处找到 {type(current_level)}")
                            return data_container, None, False
                    # Ensure the next level exists and is of the correct type.
                    if isinstance(next_key_or_index, int):  # next segment is an array index
                        if key_or_index not in current_level or not isinstance(current_level.get(key_or_index), list):
                            self.logger.debug(f"[Util] 路径 '{key_or_index}' 下需要列表 (为索引 {next_key_or_index} 做准备),将创建空列表。")
                            current_level[key_or_index] = []
                        current_level = current_level[key_or_index]
                    else:  # next segment is a dictionary key
                        if key_or_index not in current_level or not isinstance(current_level.get(key_or_index), dict):
                            self.logger.debug(f"[Util] 路径 '{key_or_index}' 下需要字典 (为键 '{next_key_or_index}' 做准备),将创建空字典。")
                            current_level[key_or_index] = {}
                        current_level = current_level[key_or_index]
                elif isinstance(key_or_index, int):  # current segment is an array index
                    if not isinstance(current_level, list):
                        self.logger.error(f"[Util] 路径期望列表以应用索引 '{key_or_index}',但找到 {type(current_level)}")
                        return data_container, None, False
                    # Grow the list as needed, filling with dict/list depending on the next segment.
                    while len(current_level) <= key_or_index:
                        if isinstance(next_key_or_index, str):  # next segment is a dict key
                            self.logger.debug(f"[Util] 数组在索引 {key_or_index} 处需要元素,将添加空字典。")
                            current_level.append({})
                        else:  # next segment is an array index
                            self.logger.debug(f"[Util] 数组在索引 {key_or_index} 处需要元素,将添加空列表。")
                            current_level.append([])
                    # Ensure the element at the index has the right type for the next segment.
                    if isinstance(next_key_or_index, str):  # next segment is a dict key
                        if not isinstance(current_level[key_or_index], dict):
                            self.logger.debug(f"[Util] 数组项 at index {key_or_index} 需要是字典。将被替换。")
                            current_level[key_or_index] = {}
                    elif isinstance(next_key_or_index, int):  # next segment is an array index
                        if not isinstance(current_level[key_or_index], list):
                            self.logger.debug(f"[Util] 数组项 at index {key_or_index} 需要是列表。将被替换。")
                            current_level[key_or_index] = []
                    current_level = current_level[key_or_index]
                else:
                    self.logger.error(f"[Util] 路径部分 '{key_or_index}' 类型未知 ({type(key_or_index)}).")
                    return data_container, None, False
    except Exception as e:
        self.logger.error(f"[Util] 在准备移除字段路径 '{'.'.join(map(str,path))}' 时发生错误: {e}", exc_info=True)
        return data_container, None, False
    # Defensive: the loop should always return on the last segment.
    self.logger.error(f"[Util] _util_remove_value_at_path 未能在循环内按预期返回。路径: {'.'.join(map(str,path))}")
    return data_container, None, False
# --- 新增:场景测试执行相关方法 ---
def _resolve_value_from_context_or_literal(self, value_template: Any, stage_context: Dict[str, Any], step_name_for_log: str) -> Any:
"""
解析一个值,如果它是字符串且符合 {{stage_context.变量}} 格式,则从阶段上下文中取值,否则直接返回值。
支持从字典和列表中深入取值,例如 {{stage_context.user.id}} 或 {{stage_context.items[0].name}}。
"""
if isinstance(value_template, str):
match = re.fullmatch(r"\{\{\s*stage_context\.([a-zA-Z0-9_\.\[\]]+)\s*\}\}", value_template)
if match:
path_expression = match.group(1)
self.logger.debug(f"[阶段步骤 '{step_name_for_log}'] 解析上下文路径: '{path_expression}' 来自模板 '{value_template}'")
try:
current_value = stage_context
# Improved path splitting to handle array indices correctly
# e.g. items[0].id -> ["items", "0", "id"]
parts = re.split(r'\.(?![^\[]*\])|\[|\]', path_expression)
parts = [p for p in parts if p] # Remove empty strings from split, e.g. from "[0]"
for part in parts:
if isinstance(current_value, dict):
current_value = current_value[part]
elif isinstance(current_value, list) and part.isdigit():
current_value = current_value[int(part)]
else:
raise KeyError(f"路径部分 '{part}' 无法从当前值类型 {type(current_value)} (路径: {path_expression}) 中解析")
self.logger.info(f"[测试阶段步骤 '{step_name_for_log}'] 从上下文解析到值 '{str(current_value)[:100]}...' (路径: '{path_expression}')")
return current_value
except (KeyError, IndexError, TypeError) as e:
self.logger.error(f"[测试阶段步骤 '{step_name_for_log}'] 从阶段上下文解析路径 '{path_expression}' 失败: {e}", exc_info=False) # exc_info=False to reduce noise for expected failures
return None # Return None or original template to indicate failure
return value_template # Not a valid placeholder, return original string
elif isinstance(value_template, list):
return [self._resolve_value_from_context_or_literal(item, stage_context, step_name_for_log) for item in value_template]
elif isinstance(value_template, dict):
return {k: self._resolve_value_from_context_or_literal(v, stage_context, step_name_for_log) for k, v in value_template.items()}
else:
return value_template # Other types directly return
def _extract_outputs_to_context(self, response_data: Dict[str, Any], outputs_map: Dict[str, str], stage_context: Dict[str, Any], step_name_for_log: str):
    """Extract values from an API response into the stage context per *outputs_map*.

    Args:
        response_data: Dict containing 'json_content', 'headers', 'status_code'.
        outputs_map: Extraction spec, e.g. {"user_id": "body.data.id",
            "token": "header.X-Auth-Token"}. Supported path prefixes are
            "body.", "header." and "status_code".
        stage_context: Stage context dict, updated in place.
        step_name_for_log: Current step name, used in log messages.
    """
    if not outputs_map or response_data is None:
        return
    for context_var_name, extraction_path in outputs_map.items():
        self.logger.debug(f"[{step_name_for_log}] 尝试提取 '{extraction_path}' 到上下文变量 '{context_var_name}'")
        value_to_extract = None
        try:
            path_parts = extraction_path.split('.')
            source_type = path_parts[0].lower()
            actual_path_parts = path_parts[1:]
            target_obj = None
            if source_type == "body":
                target_obj = response_data.get('json_content')
            elif source_type == "header":
                target_obj = response_data.get('headers')
                # Header keys are case-insensitive; normalize keys for the lookup.
                if target_obj and actual_path_parts:
                    normalized_headers = {k.lower(): v for k, v in target_obj.items()}
                    # Direct header access like header.X-Some-Token: check whether
                    # the first path part is a key in the normalized mapping.
                    if actual_path_parts[0].lower() in normalized_headers:
                        target_obj = normalized_headers
                        actual_path_parts[0] = actual_path_parts[0].lower()  # keep the lookup lowercase too
                    # If it is not a direct key it might be a deeper path inside a
                    # structured header (uncommon); for simplicity only direct
                    # access is normalized here.
            elif source_type == "status_code":
                if not actual_path_parts:
                    value_to_extract = response_data.get('status_code')
                    stage_context[context_var_name] = value_to_extract
                    self.logger.info(f"[{step_name_for_log}] 提取到 '{context_var_name}': {value_to_extract}")
                    continue
                else:
                    self.logger.warning(f"[{step_name_for_log}] status_code 不支持进一步的路径提取: '{extraction_path}'")
                    continue
            else:
                self.logger.warning(f"[{step_name_for_log}] 未知的提取源类型 '{source_type}' in path '{extraction_path}'")
                continue
            if target_obj is None and source_type != "status_code":
                self.logger.warning(f"[{step_name_for_log}] 提取源 '{source_type}' 为空或不存在。无法提取 '{extraction_path}'")
                continue
            # Walk the remaining path segments through dicts and lists.
            temp_val = target_obj
            for part_idx, part in enumerate(actual_path_parts):
                if isinstance(temp_val, dict):
                    # Headers may already be key-normalized: try lowercase first, then exact.
                    if source_type == "header" and part.lower() in temp_val:  # already normalized
                        temp_val = temp_val.get(part.lower())
                    elif part in temp_val:  # original case, or a body path
                        temp_val = temp_val.get(part)
                    elif source_type == "header" and part.lower() not in temp_val:  # last-resort header scan
                        self.logger.debug(f"[{step_name_for_log}] Header key '{part}' (path: '{extraction_path}') not found directly, checking case-insensitively if not already done.")
                        # Mostly redundant when the earlier normalization handled direct
                        # keys; covers deeper path parts that were not normalized above.
                        found_case_insensitive = False
                        for k_header, v_header in temp_val.items():
                            if k_header.lower() == part.lower():
                                temp_val = v_header
                                found_case_insensitive = True
                                break
                        if not found_case_insensitive:
                            temp_val = None; break
                    else:  # key not found in the dict
                        temp_val = None; break
                elif isinstance(temp_val, list) and part.isdigit():
                    idx = int(part)
                    if 0 <= idx < len(temp_val):
                        temp_val = temp_val[idx]
                    else:
                        self.logger.warning(f"[{step_name_for_log}] 数组索引 '{idx}' (来自部分 '{part}') 超出范围 for '{extraction_path}'. 列表长度: {len(temp_val)}.")
                        temp_val = None; break
                else:
                    self.logger.warning(f"[{step_name_for_log}] 无法从类型 {type(temp_val)} (路径 '{'.'.join(actual_path_parts[:part_idx+1])}') 中提取部分 '{part}'. 路径: '{extraction_path}'")
                    temp_val = None; break
                if temp_val is None: break
            value_to_extract = temp_val
            if value_to_extract is not None:
                stage_context[context_var_name] = value_to_extract
                self.logger.info(f"[{step_name_for_log}] 提取到上下文 '{context_var_name}': {str(value_to_extract)[:100]}...")
            else:
                self.logger.warning(f"[{step_name_for_log}] 未能从路径 '{extraction_path}' 提取到值。最终 temp_val 为 None.")
        except Exception as e:
            self.logger.error(f"[{step_name_for_log}] 从路径 '{extraction_path}' 提取值时出错: {e}", exc_info=True)
def execute_single_stage(
    self,
    stage_instance: BaseAPIStage,
    parsed_spec: ParsedAPISpec, # This is the global ParsedAPISpec object
    api_group_name: Optional[str] # Name of the current YAPI category or Swagger tag, or None if global
) -> ExecutedStageResult:
    """Execute one multi-step API test stage end to end.

    Flow: run the stage's ``before_stage`` hook, then for every step
    definition: run ``before_step``, resolve the endpoint spec, build the
    request (defaults + context-resolved overrides), call the API, validate
    (response assertions + expected status codes), extract outputs into the
    shared ``stage_context``, and run ``after_step``.  Finally derive the
    stage's overall status from the step results and run ``after_stage``.

    Args:
        stage_instance: A stage already instantiated for its API group.
        parsed_spec: The global parsed API specification object.
        api_group_name: YAPI category / Swagger tag name this run targets,
            or None when the stage runs against the whole spec.

    Returns:
        ExecutedStageResult carrying per-step results, the final shared
        context and the aggregated overall status.
    """
    stage_start_time = datetime.datetime.now()
    # Mutable key/value store shared across the steps of this stage: steps
    # write extracted response values here, later steps resolve request
    # overrides from it.
    stage_context: Dict[str, Any] = {}
    executed_steps_results: List[ExecutedStageStepResult] = []
    stage_id = stage_instance.id
    stage_name = stage_instance.name
    self.logger.info(f"开始执行测试阶段: ID='{stage_id}', Name='{stage_name}', API分组='{api_group_name or 'Global'}'")
    stage_result = ExecutedStageResult(
        stage_id=stage_id,
        stage_name=stage_name,
        description=stage_instance.description,
        api_group_metadata=stage_instance.current_api_group_metadata,
        apis_in_group=stage_instance.apis_in_group, # pass the endpoint object list through to the result
        tags=stage_instance.tags
    )
    # A failing before_stage hook aborts the whole stage with ERROR before
    # any step runs.
    try:
        self.logger.debug(f"调用 stage '{stage_id}' 的 before_stage 钩子。初始上下文: {stage_context}")
        stage_instance.before_stage(stage_context=stage_context, global_api_spec=parsed_spec, api_group_name=api_group_name)
    except Exception as e_bs:
        self.logger.error(f"Stage '{stage_id}' 的 before_stage 钩子执行失败: {e_bs}", exc_info=True)
        stage_result.overall_status = ExecutedStageResult.Status.ERROR
        stage_result.message = f"before_stage hook failed: {e_bs}"
        stage_result.finalize_stage_result(final_context=stage_context)
        return stage_result
    for step_index, step_definition in enumerate(stage_instance.steps):
        step_start_time = datetime.datetime.now()
        current_step_name = step_definition.name or f"Step {step_index + 1}"
        step_log_prefix = f"Stage '{stage_id}', Step '{current_step_name}'"
        self.logger.info(f"{step_log_prefix}: 开始执行.")
        current_step_result = ExecutedStageStepResult(
            step_name=current_step_name,
            description=step_definition.description,
            lookup_key=step_definition.endpoint_spec_lookup_key,
            status=ExecutedStageStepResult.Status.PENDING # Initial status
        )
        try:
            self.logger.debug(f"{step_log_prefix}: 调用 before_step 钩子. 上下文: {stage_context}")
            stage_instance.before_step(step=step_definition, stage_context=stage_context, global_api_spec=parsed_spec, api_group_name=api_group_name)
            self.logger.debug(f"{step_log_prefix}: 查找端点定义. Key='{step_definition.endpoint_spec_lookup_key}', Group='{api_group_name}'")
            api_op_spec: Optional[Union[APIOperationSpec, BaseEndpoint]] = stage_instance.get_api_spec_for_operation(
                lookup_key=step_definition.endpoint_spec_lookup_key,
                global_api_spec=parsed_spec,
                api_group_name=api_group_name
            )
            actual_endpoint_spec_dict = None
            endpoint_path = "N/A"
            endpoint_method = "N/A"
            # Normalize the two supported lookup results (parsed endpoint
            # object vs. APIOperationSpec-style wrapper) into a plain spec
            # dict plus path/method.
            if isinstance(api_op_spec, BaseEndpoint):
                actual_endpoint_spec_dict = api_op_spec.to_dict()
                endpoint_path = api_op_spec.path
                endpoint_method = api_op_spec.method
            elif hasattr(api_op_spec, 'spec') and api_op_spec.spec:
                actual_endpoint_spec_dict = api_op_spec.spec
                endpoint_path = getattr(api_op_spec, 'path', 'N/A')
                endpoint_method = getattr(api_op_spec, 'method', 'N/A')
            if not api_op_spec or not actual_endpoint_spec_dict:
                # Unresolvable endpoint: mark the step ERROR but let the
                # finally block still run the after_step hook.
                current_step_result.status = ExecutedStageStepResult.Status.ERROR
                current_step_result.message = f"找不到端点定义 (Key: '{step_definition.endpoint_spec_lookup_key}', Group: '{api_group_name}')."
                self.logger.error(f"{step_log_prefix}: {current_step_result.message}")
            else:
                current_step_result.resolved_endpoint = f"{endpoint_method.upper()} {endpoint_path}"
                current_step_result.status = ExecutedStageStepResult.Status.PASSED # Assume pass initially
                self.logger.info(f"{step_log_prefix}: 已解析端点 '{current_step_result.resolved_endpoint}'. 准备请求数据.")
                # Start from generated/default request data for the endpoint,
                # then layer the step's overrides on top.
                base_request_context: APIRequestContext = self._prepare_initial_request_data(actual_endpoint_spec_dict, None)
                final_path_params = copy.deepcopy(base_request_context.path_params)
                final_query_params = copy.deepcopy(base_request_context.query_params)
                final_headers = copy.deepcopy(base_request_context.headers)
                final_body = base_request_context.body # This could be None or generated data
                # Apply overrides, resolving context variables
                if step_definition.request_overrides:
                    self.logger.debug(f"{step_log_prefix}: 应用请求覆盖: {step_definition.request_overrides}")
                    for key, value_template in step_definition.request_overrides.items():
                        resolved_value = self._resolve_value_from_context_or_literal(value_template, stage_context, current_step_name)
                        if key == "path_params" and isinstance(resolved_value, dict): final_path_params.update(resolved_value)
                        elif key == "query_params" and isinstance(resolved_value, dict): final_query_params.update(resolved_value)
                        elif key == "headers" and isinstance(resolved_value, dict): final_headers.update(resolved_value)
                        elif key == "body": final_body = resolved_value # Override entire body
                        else: self.logger.warning(f"{step_log_prefix}: 未知的请求覆盖键 '{key}' 或类型不匹配.")
                # Default Content-Type for JSON body if body exists and Content-Type not set
                if final_body is not None and isinstance(final_body, (dict, list)) and not any(h.lower() == 'content-type' for h in final_headers.keys()):
                    final_headers['Content-Type'] = 'application/json'
                    self.logger.debug(f"{step_log_prefix}: 为JSON请求体设置默认Content-Type: application/json")
                full_request_url = urljoin(self.base_url, self._format_url_with_path_params(endpoint_path, final_path_params))
                api_request_obj = APIRequest(
                    method=endpoint_method,
                    url=full_request_url,
                    params=final_query_params,
                    headers=final_headers,
                    body=final_body # APIRequest handles if body is str/dict/list for json_data
                )
                current_step_result.request_details = api_request_obj.model_dump(mode='json')
                self.logger.info(f"{step_log_prefix}: 发送API请求: {api_request_obj.method} {api_request_obj.url}")
                api_response_obj, api_call_detail_obj = self.api_caller.call_api(api_request_obj)
                self.global_api_call_details.append(api_call_detail_obj) # Log globally
                current_step_result.api_call_details = api_call_detail_obj.model_dump(mode='json')
                self.logger.info(f"{step_log_prefix}: 收到响应. Status: {api_response_obj.status_code}. 验证中...")
                # NOTE(review): the status-code check below is duplicated.
                # `step_validation_points` collected here is never stored on
                # the step result (only `step_validation_results` further down
                # is), and the status code is re-validated there via
                # `_validate_status_code` — this first block only produces a
                # redundant warning log and a FAILED status that the later
                # aggregation would set anyway.
                step_validation_points: List[ValidationResult] = []
                # Status Code Check
                if step_definition.expected_status_codes and api_response_obj.status_code not in step_definition.expected_status_codes:
                    msg = f"预期状态码为 {step_definition.expected_status_codes}, 实际为 {api_response_obj.status_code}."
                    step_validation_points.append(ValidationResult(passed=False, message=msg, details={"expected": step_definition.expected_status_codes, "actual": api_response_obj.status_code}))
                    current_step_result.status = ExecutedStageStepResult.Status.FAILED
                    self.logger.warning(f"{step_log_prefix}: {msg}")
                elif step_definition.expected_status_codes: # If codes are defined and it passed
                    step_validation_points.append(ValidationResult(passed=True, message=f"状态码匹配 ({api_response_obj.status_code})."))
                # Response Assertions: build request/response contexts so
                # assertion callables can inspect both sides of the exchange.
                request_context_for_assertion = APIRequestContext(
                    method=api_request_obj.method, url=str(api_request_obj.url),
                    path_params=final_path_params, query_params=api_request_obj.params,
                    headers=api_request_obj.headers, body=api_request_obj.json_data,
                    endpoint_spec=actual_endpoint_spec_dict
                )
                response_context_for_assertion = APIResponseContext(
                    status_code=api_response_obj.status_code, headers=api_response_obj.headers,
                    json_content=api_response_obj.json_content,
                    text_content=api_response_obj.content.decode('utf-8', errors='replace') if api_response_obj.content else None, # fix: decode bytes content into text
                    elapsed_time=api_response_obj.elapsed_time, original_response=getattr(api_response_obj, 'raw_response', None),
                    request_context=request_context_for_assertion
                )
                # Use a separate list for this step's validation points
                step_validation_results: List[ValidationResult] = []
                # Assertions are callables that return ValidationResult
                for assertion in step_definition.response_assertions:
                    try:
                        validation_result = assertion(response_context_for_assertion, stage_context)
                        step_validation_results.append(validation_result)
                    except Exception as e_assert:
                        # An assertion that raises (rather than returning a
                        # failed ValidationResult) is recorded as a failure.
                        self.logger.error(f"{step_log_prefix}: Assertion function '{getattr(assertion, '__name__', 'N/A')}' raised an exception: {e_assert}", exc_info=True)
                        failed_vr = ValidationResult(passed=False, message=f"Assertion function raised an unhandled exception: {e_assert}")
                        step_validation_results.append(failed_vr)
                # Check status codes
                if step_definition.expected_status_codes:
                    status_code_vr = self._validate_status_code(
                        actual_code=response_context_for_assertion.status_code,
                        expected_codes=step_definition.expected_status_codes
                    )
                    step_validation_results.append(status_code_vr)
                current_step_result.validation_points = [vp.to_dict() for vp in step_validation_results]
                # Any failed validation point flips the step to FAILED.
                if any(not vp['passed'] for vp in current_step_result.validation_points):
                    current_step_result.status = ExecutedStageStepResult.Status.FAILED
                # Store API call details
                # NOTE(review): api_call_details was already assigned above
                # via model_dump(mode='json'); this re-assigns it (without
                # mode='json'). Harmless but redundant.
                if api_call_detail_obj:
                    if hasattr(api_call_detail_obj, 'model_dump') and callable(api_call_detail_obj.model_dump):
                        current_step_result.api_call_details = api_call_detail_obj.model_dump()
                    elif hasattr(api_call_detail_obj, 'dict') and callable(api_call_detail_obj.dict):
                        current_step_result.api_call_details = api_call_detail_obj.dict()
                    else:
                        # Fallback if it's some other object, though it should be APICallDetail
                        current_step_result.api_call_details = str(api_call_detail_obj)
                else:
                    current_step_result.api_call_details = {}
                self.logger.debug(f"{step_log_prefix}: 提取输出到上下文. Map: {step_definition.outputs_to_context}")
                # Pull declared outputs from the response into the shared
                # stage context so later steps can reference them.
                self._extract_outputs_to_context(
                    response_data={
                        "json_content": api_response_obj.json_content,
                        "headers": api_response_obj.headers,
                        "status_code": api_response_obj.status_code
                    },
                    outputs_map=step_definition.outputs_to_context,
                    stage_context=stage_context,
                    step_name_for_log=current_step_name
                )
                # Snapshot the context so the report shows state after this step.
                current_step_result.context_after_step = copy.deepcopy(stage_context)
        except Exception as step_exec_exc:
            current_step_result.status = ExecutedStageStepResult.Status.ERROR
            err_msg = f"步骤执行期间发生意外错误: {step_exec_exc}"
            current_step_result.message = f"{current_step_result.message or ''} | {err_msg}".strip(" | ")
            self.logger.error(f"{step_log_prefix}: {err_msg}", exc_info=True)
        finally:
            # Always record duration and run the after_step hook, even when
            # the step errored.
            current_step_result.duration_seconds = (datetime.datetime.now() - step_start_time).total_seconds()
            current_step_result.finalize_step_result() # Consolidate messages from validation points if main message is empty
            self.logger.debug(f"{step_log_prefix}: 调用 after_step 钩子.")
            try:
                stage_instance.after_step(
                    step=step_definition,
                    step_result=current_step_result,
                    stage_context=stage_context,
                    global_api_spec=parsed_spec,
                    api_group_name=api_group_name
                )
            except Exception as e_as:
                self.logger.error(f"{step_log_prefix}: after_step 钩子执行失败: {e_as}", exc_info=True)
                if current_step_result.status == ExecutedStageStepResult.Status.PASSED:
                    current_step_result.status = ExecutedStageStepResult.Status.ERROR # Downgrade to error if after_step fails a passed step
                # Append error to message regardless of original status if after_step fails
                current_step_result.message = f"{current_step_result.message or ''} | after_step hook failed: {e_as}".strip(" | ")
        executed_steps_results.append(current_step_result)
        self.logger.info(f"{step_log_prefix}: 执行完毕. 状态: {current_step_result.status.value}, 耗时: {current_step_result.duration_seconds:.2f}s")
        # Abort the remaining steps when one did not pass and the stage does
        # not opt in to continuing on failure.
        if current_step_result.status != ExecutedStageStepResult.Status.PASSED and not stage_instance.continue_on_failure:
            self.logger.warning(f"{step_log_prefix}: 状态为 {current_step_result.status.value} 且 continue_on_failure=False. 中止测试阶段 '{stage_id}'.")
            # Update stage_result's overall status pre-emptively if aborting
            if current_step_result.status == ExecutedStageStepResult.Status.ERROR:
                stage_result.overall_status = ExecutedStageResult.Status.ERROR
            elif current_step_result.status == ExecutedStageStepResult.Status.FAILED and stage_result.overall_status != ExecutedStageResult.Status.ERROR:
                stage_result.overall_status = ExecutedStageResult.Status.FAILED
            stage_result.message = stage_result.message or f"测试阶段因步骤 '{current_step_name}' 的状态 ({current_step_result.status.value}) 而中止."
            break # Abort stage execution
    # Determine final stage status based on all executed steps
    # Precedence: ERROR > FAILED > PASSED > SKIPPED.
    if stage_result.overall_status == ExecutedStageResult.Status.PENDING: # If not set by before_stage error or early abort
        if not executed_steps_results and stage_instance.steps:
            stage_result.overall_status = ExecutedStageResult.Status.SKIPPED # Or ERROR if this is unexpected
            stage_result.message = stage_result.message or "阶段已定义步骤但未执行任何步骤."
        elif not stage_instance.steps:
            stage_result.overall_status = ExecutedStageResult.Status.PASSED # No steps, before_stage OK
        elif any(s.status == ExecutedStageStepResult.Status.ERROR for s in executed_steps_results):
            stage_result.overall_status = ExecutedStageResult.Status.ERROR
        elif any(s.status == ExecutedStageStepResult.Status.FAILED for s in executed_steps_results):
            stage_result.overall_status = ExecutedStageResult.Status.FAILED
        elif all(s.status == ExecutedStageStepResult.Status.PASSED for s in executed_steps_results if executed_steps_results):
            stage_result.overall_status = ExecutedStageResult.Status.PASSED
        elif all(s.status == ExecutedStageStepResult.Status.SKIPPED for s in executed_steps_results if executed_steps_results):
            stage_result.overall_status = ExecutedStageResult.Status.SKIPPED
        else: # Mix of PASSED, SKIPPED, implies PASSED if no FAILED/ERROR propagated
            has_passed_step = any(s.status == ExecutedStageStepResult.Status.PASSED for s in executed_steps_results)
            if has_passed_step:
                stage_result.overall_status = ExecutedStageResult.Status.PASSED
            else: # No explicit FAILED/ERROR, no PASSED -> likely all SKIPPED or a mix that doesn't constitute overall failure
                stage_result.overall_status = ExecutedStageResult.Status.SKIPPED
                stage_result.message = stage_result.message or "所有步骤均已跳过或以非失败/错误状态完成."
    stage_result.message = stage_result.message or stage_instance.description # Default message if none set
    try:
        self.logger.debug(f"调用 stage '{stage_id}' 的 after_stage 钩子. 当前阶段结果状态: {stage_result.overall_status.value}")
        stage_instance.after_stage(stage_result=stage_result, stage_context=stage_context, global_api_spec=parsed_spec, api_group_name=api_group_name)
    except Exception as e_asg:
        # A failing after_stage hook downgrades the stage to ERROR.
        self.logger.error(f"Stage '{stage_id}' 的 after_stage 钩子执行失败: {e_asg}", exc_info=True)
        if stage_result.overall_status not in [ExecutedStageResult.Status.ERROR]:
            stage_result.overall_status = ExecutedStageResult.Status.ERROR # Downgrade
        stage_result.message = f"{stage_result.message or ''} | after_stage hook failed: {e_asg}".strip(" | ")
    stage_result.executed_steps = executed_steps_results
    stage_result.finalize_stage_result(final_context=stage_context)
    self.logger.info(f"测试阶段 '{stage_id}' (API分组: '{api_group_name or 'Global'}') 执行完毕. 最终状态: {stage_result.overall_status.value}, 耗时: {stage_result.duration:.2f}s")
    return stage_result
def run_stages_from_spec(self,
                         parsed_spec: ParsedAPISpec,
                         summary: TestSummary) -> Optional[TestSummary]:
    """Evaluate and execute every registered test stage against every API group.

    API groups are derived from the spec type: YAPI category names, Swagger
    tag names, or a single ``None`` entry meaning "whole spec".  For each
    (stage class x group) pair, a stage instance is built with the group's
    metadata and endpoint objects, its applicability is checked, and — if
    applicable — the stage is executed via ``execute_single_stage``.  All
    results are accumulated into ``summary``.

    Args:
        parsed_spec: The global parsed API specification.
        summary: Mutable run summary to which stage results are appended.

    Returns:
        The same ``summary`` object, or None on the early no-stages path.
        NOTE(review): the bare ``return`` below yields None while the normal
        path returns ``summary``; the known caller ignores the return value,
        but unifying the two would be cleaner.
    """
    self.logger.info("开始执行API测试阶段 (Stages)...")
    if not self.stage_registry or not self.stage_registry.get_all_stages():
        self.logger.info("未加载任何API测试阶段。跳过阶段执行。")
        summary.set_total_stages_defined(0) # Ensure this is set even if no stages run
        return
    stage_classes = self.stage_registry.get_all_stages()
    summary.set_total_stages_defined(len(stage_classes))
    self.logger.info(f"发现了 {len(stage_classes)} 个已定义的测试阶段: {[sc.id if hasattr(sc, 'id') else sc.__name__ for sc in stage_classes]}")
    total_stages_considered_for_execution = 0 # initialize the (stage x group) combination counter
    # Build the list of group names to iterate; None means "apply to the
    # whole spec without group filtering".
    api_groups_to_iterate: List[Optional[str]] = []
    if isinstance(parsed_spec, ParsedYAPISpec):
        if parsed_spec.categories:
            api_groups_to_iterate.extend([cat.get('name') for cat in parsed_spec.categories if cat.get('name')])
        if not api_groups_to_iterate:
            self.logger.info("YAPI规范: 未找到已命名的分类,或分类列表为空。将阶段应用于整个规范 (api_group_name=None).")
            api_groups_to_iterate.append(None)
    elif isinstance(parsed_spec, ParsedSwaggerSpec):
        # For Swagger, iterate through globally defined tags. Stages can then filter APIs based on these tags.
        # If a stage is designed for broader application, it can ignore the specific group name.
        if parsed_spec.tags:
            api_groups_to_iterate.extend([tag.get('name') for tag in parsed_spec.tags if tag.get('name')])
        if not api_groups_to_iterate:
            self.logger.info("Swagger规范: 未找到已定义的标签,或标签列表为空。将阶段应用于整个规范 (api_group_name=None).")
            api_groups_to_iterate.append(None)
    else:
        self.logger.warning(f"未知的解析规范类型: {type(parsed_spec)}。将阶段应用于整个规范 (api_group_name=None).")
        api_groups_to_iterate.append(None)
    if not api_groups_to_iterate: # Should always have at least [None] due to above logic
        self.logger.error("逻辑错误: api_groups_to_iterate 为空,将默认为 [None].")
        api_groups_to_iterate = [None]
    self.logger.info(f"将针对 {len(api_groups_to_iterate)} 个API分组评估测试阶段: {api_groups_to_iterate}")
    for stage_class_to_init in stage_classes:
        stage_id_for_log = getattr(stage_class_to_init, 'id', stage_class_to_init.__name__)
        stage_name_for_log = getattr(stage_class_to_init, 'name', stage_class_to_init.__name__)
        self.logger.info(f"处理测试阶段定义: ID='{stage_id_for_log}', Name='{stage_name_for_log}'")
        was_applicable_and_executed_at_least_once = False
        # Create a template instance once to check fail_if_not_applicable_to_any_group
        # This template instance doesn't need full group-specific data yet.
        template_stage_instance_check = stage_class_to_init(
            api_group_metadata={"name": "_template_check", "description": "用于预检查的模板实例"}, # Provide a default dict
            apis_in_group=[],
            global_api_spec=parsed_spec,
            llm_service=self.llm_service,
            stage_llm_config=self.stage_llm_config
        )
        for current_api_group_name in api_groups_to_iterate:
            total_stages_considered_for_execution += 1 # count this (stage x group) combination
            self.logger.debug(f"准备阶段 '{stage_id_for_log}' 的上下文针对API分组: '{current_api_group_name or 'Global'}'")
            current_group_metadata: Optional[Dict[str, Any]] = None
            apis_for_current_group_objects: List[Union[YAPIEndpoint, SwaggerEndpoint]] = []
            # Resolve the group's metadata and its endpoint objects,
            # depending on spec type and whether a group name is set.
            if current_api_group_name:
                if isinstance(parsed_spec, ParsedYAPISpec):
                    category_metadata_from_spec = next((cat for cat in parsed_spec.categories if cat.get('name') == current_api_group_name), None)
                    if not category_metadata_from_spec:
                        self.logger.warning(f"YAPI group '{current_api_group_name}': Could not find its metadata in parsed_spec.categories. Skipping this group for stage '{stage_id_for_log}'.")
                        current_group_metadata = {'name': current_api_group_name, 'description': '分类元数据未找到'}
                    else:
                        current_group_metadata = dict(category_metadata_from_spec)
                    # YAPI endpoints are matched by numeric category_id, which
                    # is recovered from the first endpoint whose category_name
                    # matches this group.
                    effective_category_id: Optional[int] = None
                    for ep_lookup in parsed_spec.endpoints:
                        if isinstance(ep_lookup, YAPIEndpoint) and ep_lookup.category_name == current_api_group_name:
                            effective_category_id = ep_lookup.category_id
                            break
                    if effective_category_id is not None:
                        current_group_metadata['id'] = effective_category_id
                        apis_for_current_group_objects = [
                            api for api in parsed_spec.endpoints
                            if isinstance(api, YAPIEndpoint) and api.category_id == effective_category_id
                        ]
                        self.logger.debug(f"For YAPI group '{current_api_group_name}' (resolved ID: {effective_category_id}), selected {len(apis_for_current_group_objects)} endpoint objects.")
                    else:
                        self.logger.warning(
                            f"YAPI group '{current_api_group_name}': Could not determine a definitive integer category ID "
                            f"from its endpoints (metadata ID from spec was '{current_group_metadata.get('id')}'). "
                            f"No APIs will be selected for this group for stage '{stage_id_for_log}'."
                        )
                elif isinstance(parsed_spec, ParsedSwaggerSpec):
                    tag_metadata_obj = next((tag for tag in parsed_spec.tags if tag.get('name') == current_api_group_name), None)
                    if tag_metadata_obj:
                        current_group_metadata = {'name': current_api_group_name, 'description': tag_metadata_obj.get('description', '')}
                    else:
                        self.logger.warning(f"Swagger tag '{current_api_group_name}': Could not find its metadata in global tags. Will use name only.")
                        current_group_metadata = {'name': current_api_group_name, 'description': '标签定义可能仅在操作上'}
                    apis_for_current_group_objects = [
                        api for api in parsed_spec.endpoints
                        if isinstance(api, SwaggerEndpoint) and hasattr(api, 'tags') and isinstance(api.tags, list) and current_api_group_name in api.tags
                    ]
                    self.logger.debug(f"For Swagger group '{current_api_group_name}', selected {len(apis_for_current_group_objects)} endpoint objects.")
                else:
                    self.logger.warning(f"Unknown spec type ({type(parsed_spec)}) for group '{current_api_group_name}'.")
                    current_group_metadata = {"name": current_api_group_name, "description": f"未知规范类型 ({type(parsed_spec)}) 的分组"}
            else:
                # Global run: every endpoint in the spec is in scope.
                current_group_metadata = {"name": "Global (所有API)", "description": "适用于规范中的所有API"}
                if parsed_spec and parsed_spec.endpoints:
                    apis_for_current_group_objects = list(parsed_spec.endpoints)
                else:
                    apis_for_current_group_objects = []
                self.logger.debug(f"For Global context, selected {len(apis_for_current_group_objects)} endpoint objects.")
            if current_api_group_name and not current_group_metadata:
                self.logger.error(f"INTERNAL ERROR: Metadata for group '{current_api_group_name}' ended up None. Skipping.")
                continue
            current_apis_in_group_for_stage = apis_for_current_group_objects
            self.logger.debug(f"为阶段 '{stage_id_for_log}' 和分组 '{current_api_group_name or 'Global'}' 实例化。API对象数量: {len(current_apis_in_group_for_stage)}. 元数据: {current_group_metadata}")
            stage_instance_for_execution = stage_class_to_init(
                api_group_metadata=current_group_metadata,
                apis_in_group=current_apis_in_group_for_stage,
                global_api_spec=parsed_spec,
                llm_service=self.llm_service,
                stage_llm_config=self.stage_llm_config
            )
            try:
                self.logger.debug(f"检查阶段 '{stage_instance_for_execution.id}' 是否适用于API分组 '{current_api_group_name or 'Global'}'...")
                # is_applicable_to_api_group expects the api_group_name string and the full spec
                applicable = stage_instance_for_execution.is_applicable_to_api_group(
                    api_group_name=current_api_group_name, # Pass the name string
                    global_api_spec=parsed_spec
                )
            except Exception as e_is_app:
                # A crashing applicability check is recorded as an ERROR
                # stage result rather than aborting the whole run.
                self.logger.error(f"检查阶段 '{stage_instance_for_execution.id}' 对分组 '{current_api_group_name or 'Global'}' 的适用性时出错: {e_is_app}", exc_info=True)
                error_result = ExecutedStageResult(
                    stage_id=stage_instance_for_execution.id,
                    stage_name=stage_instance_for_execution.name,
                    description=stage_instance_for_execution.description,
                    api_group_metadata=current_group_metadata,
                    tags=stage_instance_for_execution.tags,
                    overall_status=ExecutedStageResult.Status.ERROR,
                    message=f"适用性检查期间出错: {e_is_app}"
                )
                error_result.finalize_stage_result(final_context={}) # Minimal context
                summary.add_stage_result(error_result)
                # NOTE(review): an errored applicability check is counted as
                # an "attempt", which suppresses the
                # fail_if_not_applicable_to_any_group failure below.
                was_applicable_and_executed_at_least_once = True # Count as an attempt
                continue # Skip to next group or stage
            if applicable:
                self.logger.info(f"测试阶段 '{stage_instance_for_execution.id}' 适用于API分组 '{current_api_group_name or 'Global'}'。开始执行...")
                stage_execution_result = self.execute_single_stage(
                    stage_instance=stage_instance_for_execution,
                    parsed_spec=parsed_spec,
                    api_group_name=current_api_group_name # Pass the group name string
                )
                summary.add_stage_result(stage_execution_result)
                was_applicable_and_executed_at_least_once = True # record that the stage actually ran
            else: # This else corresponds to 'if applicable:'
                self.logger.info(f"测试阶段 '{stage_instance_for_execution.id}' 不适用于API分组 '{current_api_group_name or 'Global'}'。跳过此分组。")
        # The fail_if_not_applicable_to_any_group check deliberately uses the
        # template instance (group-agnostic) created above.
        if not was_applicable_and_executed_at_least_once and template_stage_instance_check.fail_if_not_applicable_to_any_group:
            self.logger.warning(f"测试阶段 '{template_stage_instance_check.id}' 未适用于任何API分组'fail_if_not_applicable_to_any_group' 为 True.")
            # Build the failure result from the template instance's attributes.
            failure_result = ExecutedStageResult(
                stage_id=template_stage_instance_check.id,
                stage_name=template_stage_instance_check.name,
                description=template_stage_instance_check.description,
                api_group_metadata=None, # None because it was never applied to a specific group
                tags=template_stage_instance_check.tags,
                overall_status=ExecutedStageResult.Status.FAILED,
                message=f"阶段标记为 '必须应用' 但未适用于任何评估的分组: {[g for g in api_groups_to_iterate if g is not None]}"
            )
            failure_result.finalize_stage_result(final_context={})
            summary.add_stage_result(failure_result)
    self.logger.info(f"API Test Stage execution processed. Considered {total_stages_considered_for_execution} (stage_definition x api_group) combinations.")
    return summary # <-- ADDED
def _execute_tests_from_parsed_spec(self,
                                    parsed_spec: ParsedAPISpec,
                                    summary: TestSummary,
                                    categories: Optional[List[str]] = None,
                                    tags: Optional[List[str]] = None,
                                    custom_test_cases_dir: Optional[str] = None
                                    ) -> TestSummary:
    """Run applicable test cases against an already-parsed API spec.

    Endpoints are optionally narrowed by ``categories`` (YAPI/DMS specs) or
    ``tags`` (Swagger specs), summary counters are updated, and each
    remaining endpoint is executed via ``run_test_for_endpoint``.

    Args:
        parsed_spec: Parsed specification holding the endpoint objects.
        summary: Mutable run summary that accumulates endpoint results.
        categories: Category-name filter for YAPI/DMS endpoints.
        tags: Tag-name filter for Swagger endpoints.
        custom_test_cases_dir: If given and different from the current
            registry directory, the test-case registry is rebuilt from it.

    Returns:
        The same ``summary`` object, updated in place.
    """
    # Rebuild the registry only when a *different* custom directory is requested.
    if custom_test_cases_dir and (not self.test_case_registry or self.test_case_registry.test_cases_dir != custom_test_cases_dir):
        self.logger.info(f"Re-initializing TestCaseRegistry with new directory: {custom_test_cases_dir}")
        self.test_case_registry = TestCaseRegistry(test_cases_dir=custom_test_cases_dir)

    def _matches_category(ep: Any) -> bool:
        # YAPI/DMS endpoints carry a category_name attribute.
        return hasattr(ep, 'category_name') and ep.category_name in categories

    def _matches_any_tag(ep: Any) -> bool:
        # Swagger endpoints carry a list of tag names.
        return (hasattr(ep, 'tags') and isinstance(ep.tags, list)
                and any(t in ep.tags for t in tags))

    selected: List[Union[YAPIEndpoint, SwaggerEndpoint, DMSEndpoint]] = []
    if isinstance(parsed_spec, ParsedYAPISpec):
        pool = parsed_spec.endpoints
        selected = [ep for ep in pool if _matches_category(ep)] if categories else pool
    elif isinstance(parsed_spec, ParsedSwaggerSpec):
        pool = parsed_spec.endpoints
        selected = [ep for ep in pool if _matches_any_tag(ep)] if tags else pool
    elif isinstance(parsed_spec, ParsedDMSSpec):
        pool = parsed_spec.endpoints
        selected = [ep for ep in pool if _matches_category(ep)] if categories else pool

    # Counters accumulate across calls, hence the "current + new" additions.
    summary.set_total_endpoints_defined(summary.total_endpoints_defined + len(selected))
    if self.test_case_registry:
        applicable_total = 0
        for ep in selected:
            applicable_total += len(self.test_case_registry.get_applicable_test_cases(ep.method.upper(), ep.path))
        summary.set_total_test_cases_applicable(summary.total_test_cases_applicable + applicable_total)

    for ep in selected:
        summary.add_endpoint_result(self.run_test_for_endpoint(ep, global_api_spec=parsed_spec))
    return summary
def run_tests_from_dms(self, domain_mapping_path: str,
                       categories: Optional[List[str]] = None,
                       custom_test_cases_dir: Optional[str] = None,
                       ignore_ssl: bool = False
                       ) -> Tuple[TestSummary, Optional[ParsedAPISpec]]:
    """Discover APIs through the dynamic DMS service and run tests on them.

    Args:
        domain_mapping_path: Path to the DMS domain-mapping definition.
        categories: Optional category-name filter applied to discovered endpoints.
        custom_test_cases_dir: Optional override for the test-case directory.
        ignore_ssl: When True, disables SSL verification for this run.

    Returns:
        ``(summary, parsed_spec)``; ``parsed_spec`` is None when discovery failed.
    """
    run_summary = TestSummary()
    self.logger.info("从DMS动态服务启动测试...")
    # An explicit True here wins; otherwise fall back to the orchestrator-wide
    # setting. NOTE(review): with a False default, a caller can never force
    # SSL verification back ON when the instance setting is True — a
    # tri-state (None) parameter would allow that; verify before changing.
    effective_ignore_ssl = ignore_ssl or self.ignore_ssl
    spec = InputParser().parse_dms_spec(domain_mapping_path, base_url=self.base_url, ignore_ssl=effective_ignore_ssl)
    if not spec:
        # Discovery failed entirely: record the error and return an empty summary.
        self.logger.error("无法从DMS服务解析API测试终止。")
        run_summary.add_error("Could not parse APIs from DMS service.")
        run_summary.finalize_summary()
        return run_summary, None
    # Stages first (multi-step flows), then the per-endpoint test cases.
    self.run_stages_from_spec(spec, run_summary)
    run_summary = self._execute_tests_from_parsed_spec(spec, run_summary, categories=categories, custom_test_cases_dir=custom_test_cases_dir)
    run_summary.finalize_summary()
    return run_summary, spec
def _validate_status_code(self, actual_code: int, expected_codes: List[int]) -> ValidationResult:
    """Validate an HTTP status code against the step's allowed codes.

    Args:
        actual_code: Status code returned by the server.
        expected_codes: Status codes accepted by the step definition.

    Returns:
        A passing ValidationResult when the code is allowed, otherwise a
        failing one carrying the expected/actual pair in ``details``.
    """
    if actual_code in expected_codes:
        return ValidationResult(passed=True, message=f"响应状态码 {actual_code} 符合预期。")
    return ValidationResult(
        passed=False,
        message=f"响应状态码不匹配。预期: {expected_codes}, 实际: {actual_code}",
        details={"expected": expected_codes, "actual": actual_code}
    )