366 lines
13 KiB
Python
366 lines
13 KiB
Python
"""规则库核心模块"""
|
||
import logging
|
||
from typing import Dict, List, Optional, Type, Union, Any
|
||
|
||
from ..models.rule_models import AnyRule, BaseRule, RuleQuery, RuleCategory, TargetType, RuleLifecycle, RuleScope
|
||
from ..models.config_models import RuleRepositoryConfig
|
||
from .adapters.base_adapter import BaseRuleStorageAdapter
|
||
from .adapters.filesystem_adapter import FilesystemAdapter
|
||
from .yaml_adapter import YAMLAdapter
|
||
# 未来可能添加的其他适配器
|
||
# from .adapters.db_adapter import DatabaseAdapter
|
||
# from .adapters.in_memory_adapter import InMemoryAdapter
|
||
|
||
class RuleRepository:
|
||
"""
|
||
规则库模块的核心类。
|
||
负责通过合适的存储适配器管理和提供规则。
|
||
"""
|
||
|
||
def __init__(self, config: RuleRepositoryConfig):
|
||
"""
|
||
初始化规则库。
|
||
|
||
Args:
|
||
config: 规则库配置
|
||
"""
|
||
self.config = config
|
||
self.logger = logging.getLogger(__name__)
|
||
|
||
# 创建适当的存储适配器
|
||
self.adapters = self._create_adapters()
|
||
|
||
# 用于在内存中缓存规则 (如果启用了preload_rules)
|
||
self.rule_cache: Dict[str, Dict[str, AnyRule]] = {} # {rule_id: {version: rule}}
|
||
|
||
# 初始化适配器
|
||
for adapter in self.adapters:
|
||
adapter.initialize()
|
||
|
||
# 如果配置了预加载规则,则加载所有规则到内存
|
||
if self.config.preload_rules:
|
||
self._preload_rules()
|
||
|
||
def _create_adapters(self) -> List[BaseRuleStorageAdapter]:
|
||
"""根据配置创建适当的存储适配器。"""
|
||
adapters = []
|
||
storage_type = self.config.storage.type.lower()
|
||
|
||
if storage_type == "filesystem":
|
||
# 添加JSON规则适配器
|
||
adapters.append(FilesystemAdapter(
|
||
base_path=self.config.storage.path or "./rules"
|
||
))
|
||
|
||
# 添加YAML规则适配器
|
||
adapters.append(YAMLAdapter(
|
||
base_path=self.config.storage.path or "./rules"
|
||
))
|
||
|
||
# 未来可能添加的其他适配器类型
|
||
# elif storage_type == "database":
|
||
# adapters.append(DatabaseAdapter(
|
||
# connection_string=self.config.storage.connection_string
|
||
# ))
|
||
# elif storage_type == "in_memory":
|
||
# adapters.append(InMemoryAdapter())
|
||
else:
|
||
raise ValueError(f"Unsupported rule storage type: {storage_type}")
|
||
|
||
return adapters
|
||
|
||
def _preload_rules(self) -> None:
|
||
"""预加载所有规则到内存缓存。"""
|
||
self.logger.info("Preloading rules from storage...")
|
||
all_rule_ids = set()
|
||
|
||
# 从所有适配器收集规则ID
|
||
for adapter in self.adapters:
|
||
rule_ids = adapter.list_all_rule_ids()
|
||
all_rule_ids.update(rule_ids)
|
||
|
||
loaded_count = 0
|
||
|
||
for rule_id in all_rule_ids:
|
||
versions_by_adapter = []
|
||
|
||
# 从所有适配器收集规则版本
|
||
for adapter in self.adapters:
|
||
versions = adapter.get_rule_versions(rule_id)
|
||
if versions:
|
||
versions_by_adapter.append((adapter, versions))
|
||
|
||
if not versions_by_adapter:
|
||
continue
|
||
|
||
if rule_id not in self.rule_cache:
|
||
self.rule_cache[rule_id] = {}
|
||
|
||
# 对于每个适配器的每个版本,尝试加载规则
|
||
for adapter, versions in versions_by_adapter:
|
||
for version in versions:
|
||
rule = adapter.load_rule_by_id(rule_id, version)
|
||
if rule:
|
||
self.rule_cache[rule_id][version] = rule
|
||
loaded_count += 1
|
||
|
||
self.logger.info(f"Preloaded {loaded_count} rules from {len(all_rule_ids)} rule IDs")
|
||
|
||
def get_rule(self, rule_id: str, version: Optional[str] = None) -> Optional[AnyRule]:
|
||
"""
|
||
获取指定ID和版本的规则。
|
||
|
||
Args:
|
||
rule_id: 规则ID
|
||
version: 规则版本(如果未指定,则使用配置的默认版本策略)
|
||
|
||
Returns:
|
||
规则对象,如果未找到则返回None
|
||
"""
|
||
# 优先从缓存中获取,如果启用了预加载
|
||
if self.config.preload_rules and rule_id in self.rule_cache:
|
||
if version and version in self.rule_cache[rule_id]:
|
||
return self.rule_cache[rule_id][version]
|
||
elif not version and self.rule_cache[rule_id]:
|
||
# 获取最新版本
|
||
latest_version = self._get_latest_version(list(self.rule_cache[rule_id].keys()))
|
||
return self.rule_cache[rule_id].get(latest_version)
|
||
|
||
# 从适配器加载
|
||
for adapter in self.adapters:
|
||
rule = adapter.load_rule_by_id(rule_id, version)
|
||
if rule:
|
||
return rule
|
||
|
||
return None
|
||
|
||
def _get_latest_version(self, versions: List[str]) -> str:
|
||
"""简单地按字符串排序获取最新版本。"""
|
||
if not versions:
|
||
return ""
|
||
versions.sort()
|
||
return versions[-1]
|
||
|
||
def query_rules(self, query: Optional[RuleQuery] = None) -> List[AnyRule]:
|
||
"""
|
||
根据查询条件查询规则。
|
||
|
||
Args:
|
||
query: 规则查询条件,如果为None则使用默认查询
|
||
|
||
Returns:
|
||
匹配规则的列表
|
||
"""
|
||
query = query or RuleQuery()
|
||
|
||
# 从所有适配器查询规则
|
||
results = []
|
||
for adapter in self.adapters:
|
||
adapter_results = adapter.query_rules(query)
|
||
if adapter_results:
|
||
results.extend(adapter_results)
|
||
|
||
# 去重(可能不同适配器返回相同ID和版本的规则)
|
||
deduplicated = {}
|
||
for rule in results:
|
||
key = f"{rule.id}:{rule.version}"
|
||
if key not in deduplicated:
|
||
deduplicated[key] = rule
|
||
|
||
return list(deduplicated.values())
|
||
|
||
def get_rules_by_tags(self, tags: List[str], match_all: bool = False) -> List[AnyRule]:
|
||
"""
|
||
根据标签查询规则。
|
||
|
||
Args:
|
||
tags: 要匹配的标签列表。
|
||
match_all: 如果为True,则规则必须包含所有指定的标签;
|
||
如果为False(默认),则规则包含任何一个指定标签即可匹配。
|
||
|
||
Returns:
|
||
匹配标签的规则列表。
|
||
"""
|
||
if not tags:
|
||
return [] # 如果没有提供标签,返回空列表
|
||
|
||
# 获取所有规则进行过滤。可以考虑优化,如果规则量非常大,
|
||
# 且适配器支持基于标签的查询,则直接调用适配器。
|
||
# 目前,我们先在查询所有规则后进行内存过滤。
|
||
all_rules = self.query_rules(RuleQuery(is_enabled=True)) # 通常只查询启用的规则
|
||
|
||
matched_rules = []
|
||
tag_set_query = set(tag.lower() for tag in tags) # 查询标签转换为小写集合以进行不区分大小写的比较
|
||
|
||
for rule in all_rules:
|
||
if not rule.tags: # 如果规则没有标签,则跳过
|
||
continue
|
||
|
||
rule_tags_set = set(t.lower() for t in rule.tags) # 规则的标签也转换为小写集合
|
||
|
||
if match_all:
|
||
# 需要匹配所有查询标签
|
||
if tag_set_query.issubset(rule_tags_set):
|
||
matched_rules.append(rule)
|
||
else:
|
||
# 只需要匹配任何一个查询标签
|
||
if not tag_set_query.isdisjoint(rule_tags_set): # 如果交集不为空
|
||
matched_rules.append(rule)
|
||
|
||
return matched_rules
|
||
|
||
def save_rule(self, rule: BaseRule) -> bool:
|
||
"""
|
||
保存规则到存储。
|
||
|
||
Args:
|
||
rule: 要保存的规则
|
||
|
||
Returns:
|
||
操作是否成功
|
||
"""
|
||
# 根据规则类别选择合适的适配器
|
||
adapter_to_use = self.adapters[0] # 默认使用第一个适配器
|
||
|
||
# 如果是YAML格式的规则,使用YAML适配器
|
||
if hasattr(rule, 'code') and rule.code:
|
||
for adapter in self.adapters:
|
||
if isinstance(adapter, YAMLAdapter):
|
||
adapter_to_use = adapter
|
||
break
|
||
|
||
result = adapter_to_use.save_rule(rule)
|
||
|
||
# 如果成功保存且启用了预加载,更新缓存
|
||
if result and self.config.preload_rules:
|
||
if rule.id not in self.rule_cache:
|
||
self.rule_cache[rule.id] = {}
|
||
self.rule_cache[rule.id][rule.version] = rule
|
||
|
||
return result
|
||
|
||
def delete_rule(self, rule_id: str, version: Optional[str] = None) -> bool:
|
||
"""
|
||
从存储中删除规则。
|
||
|
||
Args:
|
||
rule_id: 规则ID
|
||
version: 如果指定,仅删除该版本;否则删除所有版本
|
||
|
||
Returns:
|
||
操作是否成功
|
||
"""
|
||
# 从所有适配器删除规则
|
||
overall_result = True
|
||
for adapter in self.adapters:
|
||
try:
|
||
result = adapter.delete_rule(rule_id, version)
|
||
if not result:
|
||
overall_result = False
|
||
except Exception as e:
|
||
self.logger.error(f"Error deleting rule {rule_id} (version={version}) from adapter {adapter.__class__.__name__}: {e}")
|
||
overall_result = False
|
||
|
||
# 如果启用了预加载,更新缓存
|
||
if self.config.preload_rules:
|
||
if version and rule_id in self.rule_cache:
|
||
# 删除特定版本
|
||
if version in self.rule_cache[rule_id]:
|
||
del self.rule_cache[rule_id][version]
|
||
# 如果该规则没有更多版本,删除整个规则条目
|
||
if not self.rule_cache[rule_id]:
|
||
del self.rule_cache[rule_id]
|
||
elif rule_id in self.rule_cache:
|
||
# 删除所有版本
|
||
del self.rule_cache[rule_id]
|
||
|
||
return overall_result
|
||
|
||
def get_rules_for_target(self, target_type: TargetType, target_id: str) -> List[AnyRule]:
|
||
"""
|
||
获取适用于特定目标的规则。
|
||
这是一个便捷方法,用于当TestExecutor或JSONSchemaValidator需要找到适用于特定API操作或数据对象的规则。
|
||
|
||
Args:
|
||
target_type: 目标类型(如APIRequest, APIResponse, DataObject)
|
||
target_id: 目标标识符(如API操作ID, 数据对象名称)
|
||
|
||
Returns:
|
||
适用于该目标的规则列表
|
||
"""
|
||
query = RuleQuery(
|
||
target_type=target_type,
|
||
target_identifier=target_id,
|
||
is_enabled=True
|
||
)
|
||
return self.query_rules(query)
|
||
|
||
def get_rules_by_lifecycle(self, lifecycle: RuleLifecycle, target_type: Optional[TargetType] = None) -> List[AnyRule]:
|
||
"""
|
||
获取适用于特定生命周期阶段的规则。
|
||
|
||
Args:
|
||
lifecycle: 规则适用的生命周期阶段
|
||
target_type: 可选的目标类型过滤
|
||
|
||
Returns:
|
||
适用于该生命周期阶段的规则列表
|
||
"""
|
||
query = RuleQuery(
|
||
lifecycle=lifecycle,
|
||
target_type=target_type,
|
||
is_enabled=True
|
||
)
|
||
return self.query_rules(query)
|
||
|
||
def get_rules_by_scope(self, scope: RuleScope, target_type: Optional[TargetType] = None) -> List[AnyRule]:
|
||
"""
|
||
获取适用于特定作用域的规则。
|
||
|
||
Args:
|
||
scope: 规则的作用域
|
||
target_type: 可选的目标类型过滤
|
||
|
||
Returns:
|
||
适用于该作用域的规则列表
|
||
"""
|
||
query = RuleQuery(
|
||
scope=scope,
|
||
target_type=target_type,
|
||
is_enabled=True
|
||
)
|
||
return self.query_rules(query)
|
||
|
||
def get_schema_for_target(self, target_type: TargetType, target_id: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
获取适用于特定目标的JSON Schema。
|
||
这是一个便捷方法,用于当JSONSchemaValidator需要找到适用于特定API操作或数据对象的JSON Schema。
|
||
|
||
Args:
|
||
target_type: 目标类型(如APIRequest, APIResponse, DataObject)
|
||
target_id: 目标标识符(如API操作ID, 数据对象名称)
|
||
|
||
Returns:
|
||
JSON Schema字典,如果未找到则返回None
|
||
"""
|
||
query = RuleQuery(
|
||
category=RuleCategory.JSON_SCHEMA,
|
||
target_type=target_type,
|
||
target_identifier=target_id,
|
||
is_enabled=True
|
||
)
|
||
|
||
schemas = self.query_rules(query)
|
||
if not schemas:
|
||
return None
|
||
|
||
# 如果有多个匹配的Schema规则,使用最新版本的
|
||
# 注意:这里可以根据需要实现更复杂的选择逻辑
|
||
schemas.sort(key=lambda x: x.version)
|
||
latest_schema = schemas[-1]
|
||
|
||
# 假设是JSONSchemaDefinition类型
|
||
if hasattr(latest_schema, 'schema_content'):
|
||
return latest_schema.schema_content
|
||
|
||
return None |