import os import importlib.util import inspect import logging from typing import List, Type, Dict, Optional from .scenario_framework import BaseAPIScenario # 从新的场景框架模块导入 class ScenarioRegistry: """ 负责发现、加载和管理所有自定义的 BaseAPIScenario 类。 """ def __init__(self, scenarios_dir: Optional[str] = None): """ 初始化 ScenarioRegistry。 Args: scenarios_dir: 存放自定义API场景 (.py 文件) 的目录路径。如果为None,则不进行发现。 """ self.scenarios_dir = scenarios_dir self.logger = logging.getLogger(__name__) self._registry: Dict[str, Type[BaseAPIScenario]] = {} self._scenario_classes: List[Type[BaseAPIScenario]] = [] if self.scenarios_dir: self.discover_scenarios() else: self.logger.info("ScenarioRegistry 初始化时未提供 scenarios_dir,跳过场景发现。") def discover_scenarios(self): """ 扫描指定目录及其所有子目录,动态导入模块,并注册所有继承自 BaseAPIScenario 的类。 """ if not self.scenarios_dir or not os.path.isdir(self.scenarios_dir): self.logger.warning(f"API场景目录不存在或不是一个目录: {self.scenarios_dir}") return self.logger.info(f"开始从目录 '{self.scenarios_dir}' 及其子目录发现API场景...") found_count = 0 for root_dir, _, files in os.walk(self.scenarios_dir): for filename in files: if filename.endswith(".py") and not filename.startswith("__"): module_name = filename[:-3] file_path = os.path.join(root_dir, filename) try: spec = importlib.util.spec_from_file_location(module_name, file_path) if spec and spec.loader: module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) self.logger.debug(f"成功导入API场景模块: {module_name} 从 {file_path}") for name, obj in inspect.getmembers(module): if inspect.isclass(obj) and issubclass(obj, BaseAPIScenario) and obj is not BaseAPIScenario: if not hasattr(obj, 'id') or not obj.id: self.logger.error(f"API场景类 '{obj.__name__}' 在文件 '{file_path}' 中缺少有效的 'id' 属性,已跳过注册。") continue if obj.id in self._registry: self.logger.warning(f"发现重复的API场景 ID: '{obj.id}' (来自类 '{obj.__name__}' in {file_path})。之前的定义将被覆盖。") self._registry[obj.id] = obj # 更新 _scenario_classes 列表 existing_class_indices = [i for i, sc_class in enumerate(self._scenario_classes) if sc_class.id == obj.id] if existing_class_indices: for index in sorted(existing_class_indices, reverse=True): del self._scenario_classes[index] self._scenario_classes.append(obj) found_count += 1 self.logger.info(f"已注册API场景: '{obj.id}' ({getattr(obj, 'name', 'N/A')}) 来自类 '{obj.__name__}' (路径: {file_path})") else: self.logger.error(f"无法为文件 '{file_path}' 创建模块规范 (用于API场景)。") except ImportError as e: self.logger.error(f"导入API场景模块 '{module_name}' 从 '{file_path}' 失败: {e}", exc_info=True) except AttributeError as e: self.logger.error(f"在API场景模块 '{module_name}' ({file_path}) 中查找场景时出错: {e}", exc_info=True) except Exception as e: self.logger.error(f"处理API场景文件 '{file_path}' 时发生未知错误: {e}", exc_info=True) # 场景通常不需要像单个测试用例那样排序执行顺序,除非有特定需求 # 如果需要,可以添加类似 execution_order 的属性并排序 # self._scenario_classes.sort(key=lambda sc_class: (getattr(sc_class, 'execution_order', 100), sc_class.__name__)) self.logger.info(f"API场景发现完成。总共注册了 {len(self._registry)} 个独特的API场景 (基于ID)。发现并加载了 {len(self._scenario_classes)} 个API场景类。") def get_scenario_by_id(self, scenario_id: str) -> Optional[Type[BaseAPIScenario]]: """根据ID获取已注册的API场景类。""" return self._registry.get(scenario_id) def get_all_scenario_classes(self) -> List[Type[BaseAPIScenario]]: """获取所有已注册的API场景类列表。""" return list(self._scenario_classes) # 返回副本