compliance/ddms_compliance_suite/scenario_registry.py
2025-06-05 15:17:51 +08:00

90 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import importlib.util
import inspect
import logging
from typing import List, Type, Dict, Optional
from .scenario_framework import BaseAPIScenario # 从新的场景框架模块导入
class ScenarioRegistry:
"""
负责发现、加载和管理所有自定义的 BaseAPIScenario 类。
"""
def __init__(self, scenarios_dir: Optional[str] = None):
"""
初始化 ScenarioRegistry。
Args:
scenarios_dir: 存放自定义API场景 (.py 文件) 的目录路径。如果为None则不进行发现。
"""
self.scenarios_dir = scenarios_dir
self.logger = logging.getLogger(__name__)
self._registry: Dict[str, Type[BaseAPIScenario]] = {}
self._scenario_classes: List[Type[BaseAPIScenario]] = []
if self.scenarios_dir:
self.discover_scenarios()
else:
self.logger.info("ScenarioRegistry 初始化时未提供 scenarios_dir跳过场景发现。")
def discover_scenarios(self):
"""
扫描指定目录及其所有子目录,动态导入模块,并注册所有继承自 BaseAPIScenario 的类。
"""
if not self.scenarios_dir or not os.path.isdir(self.scenarios_dir):
self.logger.warning(f"API场景目录不存在或不是一个目录: {self.scenarios_dir}")
return
self.logger.info(f"开始从目录 '{self.scenarios_dir}' 及其子目录发现API场景...")
found_count = 0
for root_dir, _, files in os.walk(self.scenarios_dir):
for filename in files:
if filename.endswith(".py") and not filename.startswith("__"):
module_name = filename[:-3]
file_path = os.path.join(root_dir, filename)
try:
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
self.logger.debug(f"成功导入API场景模块: {module_name}{file_path}")
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseAPIScenario) and obj is not BaseAPIScenario:
if not hasattr(obj, 'id') or not obj.id:
self.logger.error(f"API场景类 '{obj.__name__}' 在文件 '{file_path}' 中缺少有效的 'id' 属性,已跳过注册。")
continue
if obj.id in self._registry:
self.logger.warning(f"发现重复的API场景 ID: '{obj.id}' (来自类 '{obj.__name__}' in {file_path})。之前的定义将被覆盖。")
self._registry[obj.id] = obj
# 更新 _scenario_classes 列表
existing_class_indices = [i for i, sc_class in enumerate(self._scenario_classes) if sc_class.id == obj.id]
if existing_class_indices:
for index in sorted(existing_class_indices, reverse=True):
del self._scenario_classes[index]
self._scenario_classes.append(obj)
found_count += 1
self.logger.info(f"已注册API场景: '{obj.id}' ({getattr(obj, 'name', 'N/A')}) 来自类 '{obj.__name__}' (路径: {file_path})")
else:
self.logger.error(f"无法为文件 '{file_path}' 创建模块规范 (用于API场景)。")
except ImportError as e:
self.logger.error(f"导入API场景模块 '{module_name}''{file_path}' 失败: {e}", exc_info=True)
except AttributeError as e:
self.logger.error(f"在API场景模块 '{module_name}' ({file_path}) 中查找场景时出错: {e}", exc_info=True)
except Exception as e:
self.logger.error(f"处理API场景文件 '{file_path}' 时发生未知错误: {e}", exc_info=True)
# 场景通常不需要像单个测试用例那样排序执行顺序,除非有特定需求
# 如果需要,可以添加类似 execution_order 的属性并排序
# self._scenario_classes.sort(key=lambda sc_class: (getattr(sc_class, 'execution_order', 100), sc_class.__name__))
self.logger.info(f"API场景发现完成。总共注册了 {len(self._registry)} 个独特的API场景 (基于ID)。发现并加载了 {len(self._scenario_classes)} 个API场景类。")
def get_scenario_by_id(self, scenario_id: str) -> Optional[Type[BaseAPIScenario]]:
"""根据ID获取已注册的API场景类。"""
return self._registry.get(scenario_id)
def get_all_scenario_classes(self) -> List[Type[BaseAPIScenario]]:
"""获取所有已注册的API场景类列表。"""
return list(self._scenario_classes) # 返回副本