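"""LLM-based comprehensive compliance check.

Loads a fixed list of compliance criteria, sends all key API information
(URL, headers, params, query, body, example responses, etc.) to an LLM,
and asks it to judge whether each criterion passes and to give a reason.
"""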
import os
import json
from typing import Dict, Any, Optional, List

from ddms_compliance_suite.test_framework_core import BaseAPITestCase, TestSeverity, ValidationResult, APIRequestContext, APIResponseContext


class LLMComplianceCheckTestCase(BaseAPITestCase):
    id = "TC-LLM-COMPLIANCE-001"
    name = "Comprehensive LLM compliance check"
    description = (
        "Reads a fixed list of compliance criteria, sends all key API information "
        "(URL, headers, params, query, body, example responses, etc.) to the LLM, "
        "and asks it to judge whether each criterion passes and to give a reason."
    )
    severity = TestSeverity.MEDIUM
    tags = ["llm", "compliance", "auto-eval"]
    execution_order = 99

    def __init__(self, endpoint_spec: Dict[str, Any], global_api_spec: Dict[str, Any], json_schema_validator: Optional[Any] = None, llm_service: Optional[Any] = None):
        super().__init__(endpoint_spec, global_api_spec, json_schema_validator=json_schema_validator, llm_service=llm_service)
        # Load the compliance criteria that ship next to this module
        criteria_path = os.path.join(os.path.dirname(__file__), "compliance_criteria.json")
        with open(criteria_path, "r", encoding="utf-8") as f:
            self.compliance_criteria = json.load(f)
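        # Illustrative criteria shape (hypothetical example; the shipped
        # compliance_criteria.json defines the real list), e.g.:
        #   ["All responses must be JSON", "Error responses must include an error code"]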
self.logger.info(f"已加载合规性标准: {self.compliance_criteria}")
|
||
|
||
def validate_response(self, response_context: APIResponseContext, request_context: APIRequestContext) -> List[ValidationResult]:
|
||
results = []
|
||
|
||
# 如果合规性标准列表为空,则跳过测试
|
||
if not self.compliance_criteria:
|
||
return [ValidationResult(
|
||
passed=True,
|
||
message="合规性标准列表为空,跳过LLM合规性检查。",
|
||
details={"reason": "compliance_criteria.json is empty or contains an empty list."}
|
||
)]
|
||
|
||
# 收集API所有关键信息,包括实例数据和Schema定义
|
||
api_info = {
|
||
# API元数据
|
||
"path_template": self.endpoint_spec.get("path"),
|
||
"method": request_context.method,
|
||
# "operationId": self.endpoint_spec.get("operationId"),
|
||
"title": self.endpoint_spec.get("summary") or self.endpoint_spec.get("title"),
|
||
"description": self.endpoint_spec.get("description") or self.endpoint_spec.get("desc"),
|
||
"tags": self.endpoint_spec.get("tags"),
|
||
|
||
# API Schema 定义 (从 endpoint_spec 获取)
|
||
"schema_parameters": self.endpoint_spec.get("parameters"),
|
||
"schema_request_body": self.endpoint_spec.get("requestBody"),
|
||
"schema_responses": self.endpoint_spec.get("responses"),
|
||
|
||
# API 调用实例数据 (从 request_context 和 response_context 获取)
|
||
"instance_url": request_context.url,
|
||
"instance_request_headers": dict(request_context.headers) if hasattr(request_context, "headers") else {},
|
||
"instance_query_params": getattr(request_context, "query_params", {}),
|
||
"instance_path_params": getattr(request_context, "path_params", {}),
|
||
"instance_request_body": getattr(request_context, "body", None),
|
||
"instance_response_status": response_context.status_code,
|
||
"instance_response_headers": dict(response_context.headers) if hasattr(response_context, "headers") else {},
|
||
"instance_response_body": response_context.text_content if hasattr(response_context, "text_content") else None
|
||
}
|
||
# 日志打印所有API信息
|
||
self.logger.info("LLM合规性检查-API信息收集: " + json.dumps(api_info, ensure_ascii=False, indent=2))
|
||
self.logger.info("LLM合规性检查-标准: " + json.dumps(self.compliance_criteria, ensure_ascii=False, indent=2))
|
||
|
||
if not self.llm_service:
|
||
results.append(ValidationResult(
|
||
passed=True,
|
||
message="LLM服务不可用,跳过本用例。",
|
||
details={"reason": "llm_service is None"}
|
||
))
|
||
return results
|
||
|
||
# 构建prompt
|
||
prompt = f"""
|
||
你是一位API合规性专家。请根据以下合规性标准,对给定的API调用信息进行逐条评估。每条标准请给出是否通过(true/false)和理由。
|
||
|
||
合规性标准:
|
||
{json.dumps(self.compliance_criteria, ensure_ascii=False, indent=2)}
|
||
|
||
API信息:
|
||
{json.dumps(api_info, ensure_ascii=False, indent=2)}
|
||
|
||
请以如下JSON格式输出:
|
||
[
|
||
{{"criterion": "标准内容", "passed": true/false, "reason": "理由"}},
|
||
...
|
||
]
|
||
"""
|
||
        messages = [
            {"role": "system", "content": "You are an API compliance expert; your output must be a strict JSON array."},
            {"role": "user", "content": prompt}
        ]
        self.logger.info("Prompt sent to the LLM: " + prompt)
        # Low temperature keeps the verdicts close to deterministic
        llm_response_str = self.llm_service._execute_chat_completion_request(
            messages=messages,
            max_tokens=2048,
            temperature=0.2
        )
        if not llm_response_str:
            results.append(ValidationResult(
                passed=False,
                message="Failed to obtain a response from the LLM.",
                details={"prompt": prompt}
            ))
            return results
        self.logger.info(f"Raw LLM response: {llm_response_str}")
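        # The model may wrap the JSON array in a Markdown fence, e.g.
        # ```json ... ```; the parsing below strips such fences before
        # handing the payload to json.loads.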
        try:
            cleaned = llm_response_str.strip()
            if cleaned.startswith("```json"):
                cleaned = cleaned[7:]
            elif cleaned.startswith("```"):
                cleaned = cleaned[3:]
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
            llm_result = json.loads(cleaned)
            if not isinstance(llm_result, list):
                raise ValueError("LLM did not return a JSON array")
            # Turn each per-criterion verdict into a ValidationResult
            for item in llm_result:
                criterion = item.get("criterion", "unknown criterion")
                passed = item.get("passed", False)
                reason = item.get("reason", "no reason given")
                results.append(ValidationResult(
                    passed=passed,
                    message=f"[{criterion}] {'passed' if passed else 'failed'}: {reason}",
                    details={"criterion": criterion, "llm_reason": reason}
                ))
        except Exception as e:
            results.append(ValidationResult(
                passed=False,
                message=f"Failed to parse LLM response: {e}",
                details={"raw_llm_response": llm_response_str}
            ))
        return results
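
# A minimal usage sketch (hypothetical wiring; the suite's own runner normally
# instantiates test cases and supplies the request/response contexts):
#
#   test_case = LLMComplianceCheckTestCase(
#       endpoint_spec=endpoint_spec,      # one endpoint from the API spec
#       global_api_spec=global_api_spec,
#       llm_service=llm_service,          # exposes _execute_chat_completion_request
#   )
#   results = test_case.validate_response(response_context, request_context)
#   for r in results:
#       print(r.passed, r.message)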