gongwenxin 1138668a72 v0.1
2025-05-26 17:10:38 +08:00

274 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import copy
from typing import Dict, List, Any, Optional, Union, Tuple
# 获取模块级别的 logger
logger = logging.getLogger(__name__)
def _unescape_json_pointer_token(token: str) -> str:
    """Decode one RFC 6901 reference token (``~1`` -> ``/``, ``~0`` -> ``~``)."""
    # Order matters: decoding "~0" first would turn "~01" into "~1" and then "/".
    return token.replace("~1", "/").replace("~0", "~")


def _lookup_local_ref(ref_path: str, full_api_spec: Any) -> Tuple[bool, Any]:
    """Follow a local ``#/a/b/0`` reference through ``full_api_spec``.

    Returns ``(True, target)`` when every path segment could be followed, and
    ``(False, None)`` otherwise (the reason is logged at error level).
    """
    node = full_api_spec
    for part in ref_path[2:].split('/'):
        if isinstance(node, list):
            try:
                node = node[int(part)]
            except (ValueError, IndexError):
                logger.error(f"路径部分 '{part}' (应为整数索引) 无效或越界于列表。路径: {ref_path}")
                return False, None
        elif isinstance(node, dict):
            # Try the literal segment first so every reference that resolved
            # before still resolves identically; only then fall back to the
            # RFC 6901-decoded form (needed for keys such as "/users").
            key = part if part in node else _unescape_json_pointer_token(part)
            if key not in node:
                logger.error(f"路径部分 '{part}' 在对象中未找到。路径: {ref_path}. 可用键: {list(node.keys())}")
                return False, None
            node = node[key]
        else:
            logger.error(f"尝试在非字典/列表类型 ({type(node)}) 中访问路径部分 '{part}'。路径: {ref_path}")
            return False, None
    return True, node


def resolve_json_schema_references(
    schema_to_resolve: Any,
    full_api_spec: Dict[str, Any],
    max_depth: int = 10,
    current_depth: int = 0,
    discard_refs: bool = True  # when True, "$ref"/"$$*" keys are dropped instead of resolved
) -> Any:
    """Recursively process ``$ref`` entries inside a JSON-Schema fragment.

    Args:
        schema_to_resolve: The fragment to process (dict, list or scalar).
        full_api_spec: The complete API specification; local ``#/...``
            references are looked up in it.
        max_depth: Recursion limit that guards against circular references.
        current_depth: Depth of the current call (used by the recursion).
        discard_refs: If True, every ``$ref`` key and every key starting with
            ``$$`` is removed and the remaining values are processed
            recursively. If False, local ``$ref`` values are resolved against
            ``full_api_spec``; sibling keys of the ``$ref`` override keys of
            the resolved target when the target is a dict.

    Returns:
        A new structure with the references discarded or resolved. Once
        ``current_depth`` exceeds ``max_depth`` the fragment is returned as-is.
    """
    if current_depth > max_depth:
        logger.warning(f"达到最大$ref解析深度 ({max_depth}),可能存在循环引用。停止进一步解析。 Schema: {str(schema_to_resolve)[:200]}")
        return schema_to_resolve

    if isinstance(schema_to_resolve, list):
        return [
            resolve_json_schema_references(item, full_api_spec, max_depth, current_depth + 1, discard_refs=discard_refs)
            for item in schema_to_resolve
        ]
    if not isinstance(schema_to_resolve, dict):
        # Scalars are returned untouched.
        return schema_to_resolve

    # Work on a shallow copy so the caller's dict is never mutated.
    current_dict_processing = dict(schema_to_resolve)

    if discard_refs:
        # Mode 1: drop "$ref" and "$$"-prefixed keys, then recurse into the rest.
        ref_value = current_dict_processing.pop("$ref", None)
        if ref_value is not None:
            logger.debug(f"因 discard_refs=True丢弃 '$ref': {ref_value}")
        # Keys are not guaranteed to be strings (e.g. integer response codes
        # from a YAML-parsed spec), so check the type before startswith().
        keys_to_remove = [
            k for k in current_dict_processing
            if isinstance(k, str) and k.startswith("$$")
        ]
        for key_to_remove in keys_to_remove:
            key_val = current_dict_processing.pop(key_to_remove, None)
            logger.debug(f"因 discard_refs=True丢弃 '{key_to_remove}': {key_val}")
    elif "$ref" in current_dict_processing:
        # Mode 2: resolve the "$ref" if possible; "$$" keys are kept.
        ref_path = current_dict_processing["$ref"]
        if not isinstance(ref_path, str) or not ref_path.startswith("#/"):
            # Unsupported / external reference: "$ref" stays as an ordinary key.
            logger.warning(f"不支持的$ref格式或外部引用: {ref_path}。在非丢弃模式下,$ref将作为普通键值对处理。")
        else:
            valid_path = False
            try:
                valid_path, target = _lookup_local_ref(ref_path, full_api_spec)
                if valid_path:
                    resolved_target = copy.deepcopy(target)
                    # Sibling keys next to "$ref" override the target's keys.
                    # A non-dict target has nowhere to merge them, so they
                    # are dropped and the target is returned on its own.
                    if isinstance(resolved_target, dict):
                        for key, value in current_dict_processing.items():
                            if key != "$ref":
                                resolved_target[key] = value
                    logger.debug(f"成功解析 $ref: '{ref_path}'。将递归解析其内容(可能已与同级键合并)。")
                    return resolve_json_schema_references(
                        resolved_target, full_api_spec, max_depth, current_depth + 1, discard_refs=discard_refs
                    )
            except Exception as e:
                # Best effort: fall back to the sibling keys instead of failing.
                logger.error(f"解析$ref '{ref_path}' 时发生意外错误: {e}. 将尝试使用同级节点。", exc_info=False)
                valid_path = False
            if not valid_path:
                logger.warning(f"$ref '{ref_path}' 解析失败。将移除 $ref 并处理该对象的其余部分。")
                current_dict_processing.pop("$ref", None)

    # Common tail: recurse into whatever keys are left.
    return {
        key: resolve_json_schema_references(
            value, full_api_spec, max_depth, current_depth + 1, discard_refs=discard_refs
        )
        for key, value in current_dict_processing.items()
    }
def util_find_removable_field_path_recursive(
    current_schema: Dict[str, Any],
    current_path: List[Union[str, int]],
    full_api_spec_for_refs: Dict[str, Any],
) -> Optional[List[Union[str, int]]]:
    """(Framework helper) Find the path of the first removable required field.

    Two strategies are tried on the processed ``current_schema``:

    1. A name listed in ``required`` that is also present in ``properties``
       yields ``current_path + [name]``.
    2. Otherwise, for each property of type ``array`` whose ``items`` is an
       object schema, the first required item field that is declared in the
       item's ``properties`` yields ``current_path + [prop, 0, field]``.

    Returns ``None`` when the schema is not an object schema or neither
    strategy finds a field. Despite the name, this function does not call
    itself; it only inspects the current level and one array level below.
    """

    def _location() -> str:
        # Human-readable position used in the log messages.
        return '.'.join(map(str, current_path)) if current_path else 'root'

    schema = resolve_json_schema_references(current_schema, full_api_spec_for_refs)
    if not (isinstance(schema, dict) and schema.get("type") == "object"):
        return None

    required_names = schema.get("required", [])
    props = schema.get("properties", {})
    logger.debug(f"[Util] 递归查找路径: {current_path}, 当前层级必填字段: {required_names}, 属性: {list(props.keys())}")

    # Strategy 1: a required field declared directly at this level.
    if required_names and props:
        for field_name in required_names:
            if field_name not in props:
                continue
            logger.info(f"[Util] 策略1: 在路径 {_location()} 找到可直接移除的必填字段: '{field_name}'")
            return current_path + [field_name]

    # Strategy 2: a required field inside the object items of an array property.
    if props:
        for prop_name, raw_prop_schema in props.items():
            prop_schema = resolve_json_schema_references(raw_prop_schema, full_api_spec_for_refs)
            if not isinstance(prop_schema, dict) or prop_schema.get("type") != "array":
                continue

            raw_items_schema = prop_schema.get("items")
            if not isinstance(raw_items_schema, dict):
                continue

            items_schema = resolve_json_schema_references(raw_items_schema, full_api_spec_for_refs)
            if not isinstance(items_schema, dict) or items_schema.get("type") != "object":
                continue

            item_required = items_schema.get("required", [])
            item_props = items_schema.get("properties", {})
            if not (item_required and item_props):
                continue

            hit = next((name for name in item_required if name in item_props), None)
            if not hit:
                continue

            # Index 0 addresses the first element of the array.
            found_path = current_path + [prop_name, 0, hit]
            logger.info(f"[Util] 策略2: 在数组属性 '{prop_name}' (路径 {_location()}) 的元素内找到必填字段: '{hit}'. 路径: {found_path}")
            return found_path

    logger.debug(f"[Util] 在路径 {_location()} 未通过任何策略找到可移除的必填字段。")
    return None
def util_remove_value_at_path(
    data_container: Any,
    path: List[Union[str, int]],
) -> Tuple[Any, Any, bool]:
    """(Framework helper) Remove the value at ``path`` from nested dicts/lists.

    The input is never mutated: removal happens on a deep copy.

    Args:
        data_container: Nested dict/list structure. ``None`` is treated as an
            empty dict or list, chosen from the type of ``path[0]``.
        path: Keys (``str``) and list indices (``int``) leading to the value.

    Returns:
        ``(container, removed_value, success)``. On success ``container`` is
        the modified copy. On failure ``removed_value`` is ``None`` and
        ``container`` holds the unmodified data: either ``data_container``
        itself or an untouched copy of it. Traversal never creates or
        overwrites anything, so a failed removal cannot corrupt the result.
    """
    if not path:
        logger.error("[Util] util_remove_value_at_path: 路径不能为空。")
        return data_container, None, False

    if data_container is None:
        if isinstance(path[0], str):
            container_copy: Any = {}
        elif isinstance(path[0], int):
            container_copy = []
        else:
            logger.error(f"[Util] util_remove_value_at_path: 路径的第一个元素 '{path[0]}' 类型未知。")
            return data_container, None, False
    else:
        container_copy = copy.deepcopy(data_container)

    path_label = '.'.join(map(str, path))
    current_level = container_copy
    try:
        # Phase 1: walk down to the parent of the value to remove. Missing
        # intermediate nodes mean there is nothing to remove, so we stop
        # instead of fabricating containers along the way.
        for key_or_index in path[:-1]:
            if isinstance(key_or_index, str):
                if not isinstance(current_level, dict):
                    logger.error(f"[Util] 路径期望字典,但在 '{key_or_index}' 处找到 {type(current_level)}")
                    return data_container, None, False
                if key_or_index not in current_level:
                    logger.warning(f"[Util] 路径中间部分 '{key_or_index}' 在对象中未找到,无法移除。路径: {path_label}")
                    return container_copy, None, False
                current_level = current_level[key_or_index]
            elif isinstance(key_or_index, int):
                if not isinstance(current_level, list):
                    logger.error(f"[Util] 路径期望列表以应用索引 '{key_or_index}',但找到 {type(current_level)}")
                    return data_container, None, False
                # Negative indices are accepted here, following normal
                # Python indexing, so paths that worked before still work.
                if not -len(current_level) <= key_or_index < len(current_level):
                    logger.warning(f"[Util] 路径中间索引 '{key_or_index}' 超出列表范围。列表长度: {len(current_level)}. 路径: {path_label}")
                    return container_copy, None, False
                current_level = current_level[key_or_index]
            else:
                logger.error(f"[Util] 路径部分 '{key_or_index}' 类型未知 ({type(key_or_index)}).")
                return data_container, None, False

        # Phase 2: remove the final path element from its parent.
        leaf = path[-1]
        if isinstance(leaf, str):
            if not isinstance(current_level, dict):
                logger.error(f"[Util] 路径的最后一部分 '{leaf}' (string key) 期望父级是字典,但找到 {type(current_level)}。路径: {path_label}")
                return data_container, None, False
            if leaf not in current_level:
                logger.warning(f"[Util] 路径的最后一部分 '{leaf}' (string key) 在对象中未找到。路径: {path_label}")
                return container_copy, None, False
            original_value = current_level.pop(leaf)
            logger.info(f"[Util] 从路径 '{path_label}' 成功移除字段 '{leaf}' (原值: '{original_value}')。")
            return container_copy, original_value, True

        if not isinstance(current_level, list):
            logger.error(f"[Util] 路径的最后一部分 '{leaf}' 期望父级是列表,但找到 {type(current_level)}。路径: {path_label}")
            return data_container, None, False
        # The final index must be a non-negative, in-range integer.
        if not (isinstance(leaf, int) and 0 <= leaf < len(current_level)):
            logger.warning(f"[Util] 路径的最后一部分索引 '{leaf}' 超出列表范围或类型不符。列表长度: {len(current_level)}. 路径: {path_label}")
            return container_copy, None, False
        original_value = current_level.pop(leaf)
        logger.info(f"[Util] 从路径 '{path_label}' 成功移除索引 '{leaf}' 的元素 (原值: '{original_value}')。")
        return container_copy, original_value, True
    except Exception as e:
        # Best-effort helper: report the problem and hand back the original.
        logger.error(f"[Util] 在准备移除字段路径 '{path_label}' 时发生错误: {e}", exc_info=True)
        return data_container, None, False