import asyncio
import json
import re
import traceback
from typing import List, Dict, Any

from mcp.client.streamable_http import streamablehttp_client
from mcp import ClientSession
from mcp.types import Tool, TextContent

from llm.llm_service import LLMService

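# Note: this script assumes an LLMService wrapper with the interface used below
# (the class itself is not shown in this file): LLMService(tools=...),
# start_new_task(prompt), execute_completion() -> (tool_name, tool_args, tool_call_id),
# add_user_message(text), and add_tool_call_response(tool_call_id, result_str).
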
# --- Configuration ---
# Streamable HTTP endpoints of the four MCP servers this agent talks to.
SERVER_ENDPOINTS = {
    "api_caller": "http://127.0.0.1:8001/mcp",
    "schema_validator": "http://127.0.0.1:8002/mcp",
    "dms_provider": "http://127.0.0.1:8003/mcp",
    "test_manager": "http://127.0.0.1:8004/mcp",
}
MAX_AGENT_LOOPS = 50  # Overall loop budget (not currently referenced below)


def mcp_tools_to_openai_format(mcp_tools: List[Tool]) -> List[Dict[str, Any]]:
    """
    Convert a list of MCP tools into the OpenAI function-calling tool format.
    """
    openai_tools = []
    for tool in mcp_tools:
        # tool is an mcp.types.Tool object, which has .name, .description, and .inputSchema
        openai_tools.append({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description or "",
                "parameters": tool.inputSchema or {"type": "object", "properties": {}}
            }
        })
    return openai_tools


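# For illustration, a converted entry for an MCP tool such as "get_api_list" with no
# required arguments would look roughly like the following (the actual description and
# parameter schema come from the serving MCP server):
#   {
#       "type": "function",
#       "function": {
#           "name": "get_api_list",
#           "description": "...",
#           "parameters": {"type": "object", "properties": {}}
#       }
#   }

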
async def get_structured_response(tool_response: Any) -> Dict[str, Any]:
    """
    Robustly extract structured content from a tool-call response.
    Handles the case where the SDK did not parse the JSON automatically and
    instead placed it inside a TextContent block.
    """
    if tool_response.structuredContent:
        # Normal case: the SDK has already parsed the structured result.
        return tool_response.structuredContent

    # Fallback: try to parse the JSON manually from the TextContent.
    if tool_response.content and isinstance(tool_response.content[0], TextContent):
        try:
            json_text = tool_response.content[0].text
            parsed_json = json.loads(json_text)
            return parsed_json
        except (json.JSONDecodeError, IndexError) as e:
            # If manual parsing also fails, raise a fatal error.
            raise RuntimeError(f"Failed to manually parse JSON from TextContent: {e}. Raw text: '{json_text}'")

    # Neither structuredContent nor parsable TextContent: raise a fatal error.
    raise RuntimeError("Tool call returned no structuredContent and no parsable TextContent.")

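# For illustration: a tool response carrying structuredContent={"status": "success"}
# is returned as-is, while a response whose only content item is a TextContent whose
# .text is the string '{"status": "success"}' falls back to json.loads.

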
async def execute_task(task: Dict, tool_to_session_map: Dict, openai_tools: List[Dict]):
    """
    Run the complete, isolated test lifecycle for a generic, prompt-driven task.
    """
    llm_service = LLMService(tools=openai_tools)

    task_name = task['name']
    prompt = task['prompt']

    print(f"\n>>>> Starting Task: {task_name} <<<<")
    llm_service.start_new_task(prompt)

    # Sub-task loop for the current task
    for sub_loop in range(25):  # upper bound on test iterations for a single task
        print("\n" + "="*20 + f" Sub-Loop for '{task_name}' ({sub_loop+1}/25) " + "="*20)

        tool_name, tool_args, tool_call_id = llm_service.execute_completion()

        if not tool_name:
            print(f"Agent: LLM did not request a tool call for task '{task_name}'. It might be confused. Ending task.")
            # Even if the LLM is confused, still try to record a failed result if record_test_result is available.
            record_session = tool_to_session_map.get("record_test_result")
            if record_session:
                # We have to guess the api_id from the prompt. This is fragile, but better than doing nothing.
                match = re.search(r"API 模型 '([^']+)'", prompt)
                api_id_guess = match.group(1) if match else "unknown"
                await record_session.call_tool("record_test_result", {"api_id": api_id_guess, "task_name": task_name, "status": "failed", "details": "LLM got confused and stopped calling tools."})
            return  # End this task

        # Core logic: if the LLM calls record_test_result, this task is finished.
        if tool_name == "record_test_result":
            print(f"Agent: LLM is recording result for task '{task_name}'. Task is complete.")
            record_session = tool_to_session_map.get("record_test_result")
            if record_session:
                # Add the task name to the arguments for better tracking.
                tool_args['task_name'] = task_name
                await record_session.call_tool(tool_name, tool_args)
            return  # Key fix: return here to exit this task's function

        if tool_name == "error_malformed_json":
            error_info = tool_args
            print(f"Agent: Detected a malformed JSON from LLM for tool '{error_info['tool_name']}'. Asking for correction.")
            # Correction prompt sent back to the LLM (kept in Chinese to match the task prompts):
            # it tells the model its last tool call had arguments that were not valid JSON,
            # shows the error and the malformed arguments, and asks it to fix them and retry.
            correction_request = f"你上次试图调用工具 '{error_info['tool_name']}',但提供的参数不是一个有效的JSON。错误是:{error_info['error']}。这是你提供的错误参数:'{error_info['malformed_arguments']}'。请修正这个错误,并重新调用该工具。"
            llm_service.add_user_message(correction_request)
            continue

        if tool_name in tool_to_session_map:
            try:
                target_session = tool_to_session_map[tool_name]
                result = await target_session.call_tool(tool_name, tool_args)

                structured_result = await get_structured_response(result)
                tool_result_str = json.dumps(structured_result, ensure_ascii=False, indent=2) if structured_result else "Tool executed successfully."

                print(f"Agent: Tool '{tool_name}' executed for '{task_name}'. Result: {tool_result_str}")
                llm_service.add_tool_call_response(tool_call_id, tool_result_str)
            except Exception as e:
                error_message = f"An exception occurred while calling tool {tool_name} for '{task_name}': {e}"
                print(f"Agent: {error_message}")
                traceback.print_exc()
                llm_service.add_tool_call_response(tool_call_id, error_message)
        else:
            error_message = f"Error: LLM tried to call an unknown tool '{tool_name}' for task '{task_name}'."
            print(f"Agent: {error_message}")
            llm_service.add_tool_call_response(tool_call_id, error_message)

    # The sub-loop limit was reached without the LLM recording a result: record the task as failed.
    print(f"Agent: Reached sub-loop limit for task '{task_name}'. Recording as failed and moving on.")
    record_session = tool_to_session_map.get("record_test_result")
    if record_session:
        match = re.search(r"API 模型 '([^']+)'", prompt)
        api_id_guess = match.group(1) if match else "unknown"
        await record_session.call_tool("record_test_result", {"api_id": api_id_guess, "task_name": task_name, "status": "failed", "details": "Reached sub-loop limit."})


async def main():
    print("LLM-Powered Agent starting...")

    # Use `async with` so that every session is closed properly.
    async with streamablehttp_client(SERVER_ENDPOINTS["api_caller"]) as (r1, w1, _), \
               streamablehttp_client(SERVER_ENDPOINTS["schema_validator"]) as (r2, w2, _), \
               streamablehttp_client(SERVER_ENDPOINTS["dms_provider"]) as (r3, w3, _), \
               streamablehttp_client(SERVER_ENDPOINTS["test_manager"]) as (r4, w4, _):

        print("Agent: All MCP server connections established.")

        async with ClientSession(r1, w1) as s1, ClientSession(r2, w2) as s2, ClientSession(r3, w3) as s3, ClientSession(r4, w4) as s4:

            await asyncio.gather(s1.initialize(), s2.initialize(), s3.initialize(), s4.initialize())

            # Map every tool name to the session that serves it.
            tool_to_session_map = {tool.name: s1 for tool in (await s1.list_tools()).tools}
            tool_to_session_map.update({tool.name: s2 for tool in (await s2.list_tools()).tools})
            tool_to_session_map.update({tool.name: s3 for tool in (await s3.list_tools()).tools})
            tool_to_session_map.update({tool.name: s4 for tool in (await s4.list_tools()).tools})

            all_mcp_tools = list(tool_to_session_map.keys())
            print(f"Total tools found: {len(all_mcp_tools)}")

            openai_tools = mcp_tools_to_openai_format([tool for session in [s1, s2, s3, s4] for tool in (await session.list_tools()).tools])
            print("Agent: LLM Service tools prepared.")

            # --- Agent-driven macro test flow ---

            # 1. Fetch all APIs to be tested
            print("\n" + "="*20 + " Phase 1: Fetching APIs " + "="*20)
            get_api_list_session = tool_to_session_map.get("get_api_list")
            if not get_api_list_session:
                raise RuntimeError("Critical Error: 'get_api_list' tool not found.")

            api_list_result = await get_api_list_session.call_tool("get_api_list", {})
            api_list_structured = await get_structured_response(api_list_result)
            response_data = api_list_structured.get("result", api_list_structured)
            api_records = response_data.get('records', [])
            api_ids_to_test = [record['id'] for record in api_records if 'id' in record]

            if not api_ids_to_test:
                raise RuntimeError("Critical Error: DMSProviderServer returned an empty list of APIs.")
            print(f"Agent: Found {len(api_ids_to_test)} APIs to test: {api_ids_to_test}")

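            # For reference, the unwrapping above assumes the tool result is roughly of the
            # form {"result": {"records": [{"id": "<api_id>", ...}, ...]}}, or carries the
            # "records" list at the top level; the exact shape comes from DMSProviderServer.
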
            # 2. Load the task templates
            print("\n" + "="*20 + " Phase 2: Loading Task Templates " + "="*20)
            try:
                with open('compliance-mcp-agent/tasks.json', 'r', encoding='utf-8') as f:
                    task_templates = json.load(f)
                print(f"Agent: Loaded {len(task_templates)} task templates.")
            except FileNotFoundError:
                raise RuntimeError("Critical Error: 'tasks.json' not found in 'compliance-mcp-agent/' directory.")
            except json.JSONDecodeError as e:
                raise RuntimeError(f"Critical Error: Failed to parse 'tasks.json'. Error: {e}")

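            # For reference, each template is expected to provide a "name" and a
            # "prompt_template" with an {api_id} placeholder (see Phase 4), and the
            # api_id-guessing regex in execute_task assumes the prompt mentions
            # "API 模型 '<api_id>'". A hypothetical entry (prompt in Chinese, matching that regex):
            #   {"name": "Happy path test",
            #    "prompt_template": "请测试 API 模型 '{api_id}' 并用 record_test_result 记录结果。"}
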
            # 3. Initialize the test plan
            print("\n" + "="*20 + " Phase 3: Initializing Test Plan " + "="*20)
            initialize_plan_session = tool_to_session_map.get("initialize_test_plan")
            if not initialize_plan_session:
                raise RuntimeError("Critical Error: 'initialize_test_plan' tool not found.")

            total_task_count = len(api_ids_to_test) * len(task_templates)
            print(f"Agent: Initializing test plan for {total_task_count} total tasks ({len(api_ids_to_test)} APIs x {len(task_templates)} templates)...")
            init_result = await initialize_plan_session.call_tool("initialize_test_plan", {"api_ids": api_ids_to_test})
            init_structured = await get_structured_response(init_result)
            init_response_data = init_structured.get("result", init_structured)
            if init_response_data.get("status") != "success":
                raise RuntimeError(f"Failed to initialize test plan. Reason: {init_response_data.get('message')}")
            print("Agent: Test plan initialized successfully in TestManager.")

            # 4. Main execution loop (M APIs x N task templates)
            print("\n" + "="*20 + " Phase 4: Main Execution Loop " + "="*20)

            execution_tasks = []
            for api_id in api_ids_to_test:
                for template in task_templates:
                    # Dynamically generate the concrete task.
                    final_prompt = template['prompt_template'].format(api_id=api_id)
                    task_name_with_api = f"{template['name']} for {api_id}"

                    task_to_run = {
                        "name": task_name_with_api,
                        "prompt": final_prompt
                    }

                    # Create one async coroutine per task.
                    execution_tasks.append(
                        execute_task(
                            task=task_to_run,
                            tool_to_session_map=tool_to_session_map,
                            openai_tools=openai_tools
                        )
                    )

            # Run all generated tasks concurrently.
            await asyncio.gather(*execution_tasks)
            print("\nAll generated tasks have concluded.")

            # 5. Final summary
            print("\n" + "="*20 + " Phase 5: Final Summary " + "="*20)
            summary_session = tool_to_session_map.get("get_test_summary")
            if summary_session:
                summary_result = await summary_session.call_tool("get_test_summary", {})
                summary_structured = await get_structured_response(summary_result)
                summary_data = summary_structured.get("result", summary_structured)
                print("Final Test Summary:")
                print(json.dumps(summary_data, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nAgent manually interrupted.")
    except Exception as e:
        print(f"\nAn unexpected error occurred in main: {e}")
        traceback.print_exc()