90 lines
5.3 KiB
Python
90 lines
5.3 KiB
Python
import os
|
||
import importlib.util
|
||
import inspect
|
||
import logging
|
||
from typing import List, Type, Dict, Optional
|
||
|
||
from .scenario_framework import BaseAPIScenario # 从新的场景框架模块导入
|
||
|
||
class ScenarioRegistry:
|
||
"""
|
||
负责发现、加载和管理所有自定义的 BaseAPIScenario 类。
|
||
"""
|
||
def __init__(self, scenarios_dir: Optional[str] = None):
|
||
"""
|
||
初始化 ScenarioRegistry。
|
||
Args:
|
||
scenarios_dir: 存放自定义API场景 (.py 文件) 的目录路径。如果为None,则不进行发现。
|
||
"""
|
||
self.scenarios_dir = scenarios_dir
|
||
self.logger = logging.getLogger(__name__)
|
||
self._registry: Dict[str, Type[BaseAPIScenario]] = {}
|
||
self._scenario_classes: List[Type[BaseAPIScenario]] = []
|
||
if self.scenarios_dir:
|
||
self.discover_scenarios()
|
||
else:
|
||
self.logger.info("ScenarioRegistry 初始化时未提供 scenarios_dir,跳过场景发现。")
|
||
|
||
def discover_scenarios(self):
|
||
"""
|
||
扫描指定目录及其所有子目录,动态导入模块,并注册所有继承自 BaseAPIScenario 的类。
|
||
"""
|
||
if not self.scenarios_dir or not os.path.isdir(self.scenarios_dir):
|
||
self.logger.warning(f"API场景目录不存在或不是一个目录: {self.scenarios_dir}")
|
||
return
|
||
|
||
self.logger.info(f"开始从目录 '{self.scenarios_dir}' 及其子目录发现API场景...")
|
||
found_count = 0
|
||
for root_dir, _, files in os.walk(self.scenarios_dir):
|
||
for filename in files:
|
||
if filename.endswith(".py") and not filename.startswith("__"):
|
||
module_name = filename[:-3]
|
||
file_path = os.path.join(root_dir, filename)
|
||
try:
|
||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||
if spec and spec.loader:
|
||
module = importlib.util.module_from_spec(spec)
|
||
spec.loader.exec_module(module)
|
||
self.logger.debug(f"成功导入API场景模块: {module_name} 从 {file_path}")
|
||
|
||
for name, obj in inspect.getmembers(module):
|
||
if inspect.isclass(obj) and issubclass(obj, BaseAPIScenario) and obj is not BaseAPIScenario:
|
||
if not hasattr(obj, 'id') or not obj.id:
|
||
self.logger.error(f"API场景类 '{obj.__name__}' 在文件 '{file_path}' 中缺少有效的 'id' 属性,已跳过注册。")
|
||
continue
|
||
|
||
if obj.id in self._registry:
|
||
self.logger.warning(f"发现重复的API场景 ID: '{obj.id}' (来自类 '{obj.__name__}' in {file_path})。之前的定义将被覆盖。")
|
||
|
||
self._registry[obj.id] = obj
|
||
# 更新 _scenario_classes 列表
|
||
existing_class_indices = [i for i, sc_class in enumerate(self._scenario_classes) if sc_class.id == obj.id]
|
||
if existing_class_indices:
|
||
for index in sorted(existing_class_indices, reverse=True):
|
||
del self._scenario_classes[index]
|
||
|
||
self._scenario_classes.append(obj)
|
||
found_count += 1
|
||
self.logger.info(f"已注册API场景: '{obj.id}' ({getattr(obj, 'name', 'N/A')}) 来自类 '{obj.__name__}' (路径: {file_path})")
|
||
else:
|
||
self.logger.error(f"无法为文件 '{file_path}' 创建模块规范 (用于API场景)。")
|
||
except ImportError as e:
|
||
self.logger.error(f"导入API场景模块 '{module_name}' 从 '{file_path}' 失败: {e}", exc_info=True)
|
||
except AttributeError as e:
|
||
self.logger.error(f"在API场景模块 '{module_name}' ({file_path}) 中查找场景时出错: {e}", exc_info=True)
|
||
except Exception as e:
|
||
self.logger.error(f"处理API场景文件 '{file_path}' 时发生未知错误: {e}", exc_info=True)
|
||
|
||
# 场景通常不需要像单个测试用例那样排序执行顺序,除非有特定需求
|
||
# 如果需要,可以添加类似 execution_order 的属性并排序
|
||
# self._scenario_classes.sort(key=lambda sc_class: (getattr(sc_class, 'execution_order', 100), sc_class.__name__))
|
||
|
||
self.logger.info(f"API场景发现完成。总共注册了 {len(self._registry)} 个独特的API场景 (基于ID)。发现并加载了 {len(self._scenario_classes)} 个API场景类。")
|
||
|
||
def get_scenario_by_id(self, scenario_id: str) -> Optional[Type[BaseAPIScenario]]:
|
||
"""根据ID获取已注册的API场景类。"""
|
||
return self._registry.get(scenario_id)
|
||
|
||
def get_all_scenario_classes(self) -> List[Type[BaseAPIScenario]]:
|
||
"""获取所有已注册的API场景类列表。"""
|
||
return list(self._scenario_classes) # 返回副本 |