compliance/compliance-mcp-agent/agent_main_loop.py
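"""
Main loop for the compliance MCP agent.

The agent connects to four MCP servers (api_caller, schema_validator,
dms_provider, test_manager), exposes their tools to an LLM in the OpenAI
function-calling format, and then drives the LLM through an M x N matrix of
compliance tests: every API model returned by the DMS provider is tested
against every task template loaded from tasks.json, with results recorded
through the test_manager server.

Assumed invocation (not stated in this file): start the four MCP servers on
ports 8001-8004, then run this script from the directory that contains
compliance-mcp-agent/, since tasks.json is opened via a relative path.
"""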
import asyncio
import json
import re
import traceback
from typing import List, Dict, Any
from mcp.client.streamable_http import streamablehttp_client
from mcp import ClientSession
from mcp.types import Tool, TextContent
from llm.llm_service import LLMService
# --- Configuration ---
SERVER_ENDPOINTS = {
    "api_caller": "http://127.0.0.1:8001/mcp",
    "schema_validator": "http://127.0.0.1:8002/mcp",
    "dms_provider": "http://127.0.0.1:8003/mcp",
    "test_manager": "http://127.0.0.1:8004/mcp",
}
MAX_AGENT_LOOPS = 50
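# Note: MAX_AGENT_LOOPS is defined here but not referenced in this module;
# the per-task limit of 25 sub-loop iterations is hard-coded in execute_task().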
def mcp_tools_to_openai_format(mcp_tools: List[Tool]) -> List[Dict[str, Any]]:
    """Convert a list of MCP tools into the OpenAI function-calling tool format."""
    openai_tools = []
    for tool in mcp_tools:
        # tool is an mcp.types.Tool object with .name, .description, and .inputSchema
        openai_tools.append({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description or "",
                "parameters": tool.inputSchema or {"type": "object", "properties": {}}
            }
        })
    return openai_tools
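# For reference, one converted entry looks roughly like the following (the tool
# name "get_api_list" is real, but the description shown is illustrative; actual
# names and schemas come from each MCP server at runtime):
# {
#     "type": "function",
#     "function": {
#         "name": "get_api_list",
#         "description": "Returns the list of API models to test.",
#         "parameters": {"type": "object", "properties": {}}
#     }
# }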
async def get_structured_response(tool_response: Any) -> Dict[str, Any]:
    """
    Robustly extract structured content from a tool-call response.
    Handles the case where the SDK fails to parse JSON automatically and
    instead returns it inside a TextContent block.
    """
    if tool_response.structuredContent:
        # Normal case: the SDK already parsed the structured content.
        return tool_response.structuredContent
    # Fallback: try to parse JSON manually from the first TextContent block.
    if tool_response.content and isinstance(tool_response.content[0], TextContent):
        json_text = tool_response.content[0].text
        try:
            return json.loads(json_text)
        except json.JSONDecodeError as e:
            # Manual parsing also failed; treat this as fatal.
            raise RuntimeError(f"Failed to manually parse JSON from TextContent: {e}. Raw text: '{json_text}'")
    # Neither structuredContent nor parsable TextContent: treat this as fatal.
    raise RuntimeError("Tool call returned no structuredContent and no parsable TextContent.")
async def execute_task(task: Dict, tool_to_session_map: Dict, openai_tools: List[Dict]):
    """
    Run the complete, isolated test lifecycle for one generic, prompt-driven task.
    """
    llm_service = LLMService(tools=openai_tools)
    task_name = task['name']
    prompt = task['prompt']
    print(f"\n>>>> Starting Task: {task_name} <<<<")
    llm_service.start_new_task(prompt)
    # Sub-task loop for the current task
    for sub_loop in range(25):  # upper bound on test iterations for a single task
        print("\n" + "="*20 + f" Sub-Loop for '{task_name}' ({sub_loop+1}/25) " + "="*20)
        tool_name, tool_args, tool_call_id = llm_service.execute_completion()
        if not tool_name:
            print(f"Agent: LLM did not request a tool call for task '{task_name}'. It might be confused. Ending task.")
            # Even if the LLM is confused, still try to record a failure (if record_test_result is available).
            record_session = tool_to_session_map.get("record_test_result")
            if record_session:
                # Guess the api_id from the prompt (matches the "API 模型 '<id>'" phrase used by the
                # prompt templates). This is fragile, but better than recording nothing.
                match = re.search(r"API 模型 '([^']+)'", prompt)
                api_id_guess = match.group(1) if match else "unknown"
                await record_session.call_tool("record_test_result", {"api_id": api_id_guess, "task_name": task_name, "status": "failed", "details": "LLM got confused and stopped calling tools."})
            return  # end this task
        # Core logic: if the LLM called record_test_result, this task is finished.
        if tool_name == "record_test_result":
            print(f"Agent: LLM is recording result for task '{task_name}'. Task is complete.")
            record_session = tool_to_session_map.get("record_test_result")
            if record_session:
                # Attach the task name to the arguments for better tracking.
                tool_args['task_name'] = task_name
                await record_session.call_tool(tool_name, tool_args)
            return  # key fix: return from this task's coroutine
        if tool_name == "error_malformed_json":
            error_info = tool_args
            print(f"Agent: Detected a malformed JSON from LLM for tool '{error_info['tool_name']}'. Asking for correction.")
            correction_request = (
                f"Your last attempt to call the tool '{error_info['tool_name']}' did not provide valid JSON arguments. "
                f"The error was: {error_info['error']}. The malformed arguments you provided were: "
                f"'{error_info['malformed_arguments']}'. Please fix this and call the tool again."
            )
            llm_service.add_user_message(correction_request)
            continue
        if tool_name in tool_to_session_map:
            try:
                target_session = tool_to_session_map[tool_name]
                result = await target_session.call_tool(tool_name, tool_args)
                structured_result = await get_structured_response(result)
                tool_result_str = json.dumps(structured_result, ensure_ascii=False, indent=2) if structured_result else "Tool executed successfully."
                print(f"Agent: Tool '{tool_name}' executed for '{task_name}'. Result: {tool_result_str}")
                llm_service.add_tool_call_response(tool_call_id, tool_result_str)
            except Exception as e:
                error_message = f"An exception occurred while calling tool {tool_name} for '{task_name}': {e}"
                print(f"Agent: {error_message}")
                traceback.print_exc()
                llm_service.add_tool_call_response(tool_call_id, error_message)
        else:
            error_message = f"Error: LLM tried to call an unknown tool '{tool_name}' for task '{task_name}'."
            print(f"Agent: {error_message}")
            llm_service.add_tool_call_response(tool_call_id, error_message)
    # Falling out of the loop means the sub-loop limit was reached without a recorded result.
    print(f"Agent: Reached sub-loop limit for task '{task_name}'. Recording as failed and moving on.")
    record_session = tool_to_session_map.get("record_test_result")
    if record_session:
        match = re.search(r"API 模型 '([^']+)'", prompt)
        api_id_guess = match.group(1) if match else "unknown"
        await record_session.call_tool("record_test_result", {"api_id": api_id_guess, "task_name": task_name, "status": "failed", "details": "Reached sub-loop limit."})
async def main():
    print("LLM-Powered Agent starting...")
    # Use `async with` so that every connection is closed properly.
    async with streamablehttp_client(SERVER_ENDPOINTS["api_caller"]) as (r1, w1, _), \
            streamablehttp_client(SERVER_ENDPOINTS["schema_validator"]) as (r2, w2, _), \
            streamablehttp_client(SERVER_ENDPOINTS["dms_provider"]) as (r3, w3, _), \
            streamablehttp_client(SERVER_ENDPOINTS["test_manager"]) as (r4, w4, _):
        print("Agent: All MCP server connections established.")
        async with ClientSession(r1, w1) as s1, ClientSession(r2, w2) as s2, ClientSession(r3, w3) as s3, ClientSession(r4, w4) as s4:
            await asyncio.gather(s1.initialize(), s2.initialize(), s3.initialize(), s4.initialize())
            # Map every tool name to the session that provides it.
            tool_to_session_map = {tool.name: s1 for tool in (await s1.list_tools()).tools}
            tool_to_session_map.update({tool.name: s2 for tool in (await s2.list_tools()).tools})
            tool_to_session_map.update({tool.name: s3 for tool in (await s3.list_tools()).tools})
            tool_to_session_map.update({tool.name: s4 for tool in (await s4.list_tools()).tools})
            all_mcp_tools = list(tool_to_session_map.keys())
            print(f"Total tools found: {len(all_mcp_tools)}")
            openai_tools = mcp_tools_to_openai_format([tool for session in [s1, s2, s3, s4] for tool in (await session.list_tools()).tools])
            print("Agent: LLM Service tools prepared.")
            # --- Agent-driven macro test flow ---
            # 1. Fetch all APIs that need to be tested
            print("\n" + "="*20 + " Phase 1: Fetching APIs " + "="*20)
            get_api_list_session = tool_to_session_map.get("get_api_list")
            if not get_api_list_session:
                raise RuntimeError("Critical Error: 'get_api_list' tool not found.")
            api_list_result = await get_api_list_session.call_tool("get_api_list", {})
            api_list_structured = await get_structured_response(api_list_result)
            # Some servers wrap their payload in a top-level "result" key; unwrap it if present.
            response_data = api_list_structured.get("result", api_list_structured)
            api_records = response_data.get('records', [])
            api_ids_to_test = [record['id'] for record in api_records if 'id' in record]
            if not api_ids_to_test:
                raise RuntimeError("Critical Error: DMSProviderServer returned an empty list of APIs.")
            print(f"Agent: Found {len(api_ids_to_test)} APIs to test: {api_ids_to_test}")
            # 2. Load the task templates
            print("\n" + "="*20 + " Phase 2: Loading Task Templates " + "="*20)
            try:
                with open('compliance-mcp-agent/tasks.json', 'r', encoding='utf-8') as f:
                    task_templates = json.load(f)
                print(f"Agent: Loaded {len(task_templates)} task templates.")
            except FileNotFoundError:
                raise RuntimeError("Critical Error: 'tasks.json' not found in 'compliance-mcp-agent/' directory.")
            except json.JSONDecodeError as e:
                raise RuntimeError(f"Critical Error: Failed to parse 'tasks.json'. Error: {e}")
            # 3. Initialize the test plan
            print("\n" + "="*20 + " Phase 3: Initializing Test Plan " + "="*20)
            initialize_plan_session = tool_to_session_map.get("initialize_test_plan")
            if not initialize_plan_session:
                raise RuntimeError("Critical Error: 'initialize_test_plan' tool not found.")
            total_task_count = len(api_ids_to_test) * len(task_templates)
            print(f"Agent: Initializing test plan for {total_task_count} total tasks ({len(api_ids_to_test)} APIs x {len(task_templates)} templates)...")
            init_result = await initialize_plan_session.call_tool("initialize_test_plan", {"api_ids": api_ids_to_test})
            init_structured = await get_structured_response(init_result)
            init_response_data = init_structured.get("result", init_structured)
            if init_response_data.get("status") != "success":
                raise RuntimeError(f"Failed to initialize test plan. Reason: {init_response_data.get('message')}")
            print("Agent: Test plan initialized successfully in TestManager.")
            # 4. Main execution loop (M APIs x N task templates)
            print("\n" + "="*20 + " Phase 4: Main Execution Loop " + "="*20)
            execution_tasks = []
            for api_id in api_ids_to_test:
                for template in task_templates:
                    # Build the concrete task from the template
                    final_prompt = template['prompt_template'].format(api_id=api_id)
                    task_name_with_api = f"{template['name']} for {api_id}"
                    task_to_run = {
                        "name": task_name_with_api,
                        "prompt": final_prompt
                    }
                    # Create one coroutine per task
                    execution_tasks.append(
                        execute_task(
                            task=task_to_run,
                            tool_to_session_map=tool_to_session_map,
                            openai_tools=openai_tools
                        )
                    )
            # Run all generated tasks concurrently
            await asyncio.gather(*execution_tasks)
            print("\nAll generated tasks have concluded.")
            # 5. Final summary
            print("\n" + "="*20 + " Phase 5: Final Summary " + "="*20)
            summary_session = tool_to_session_map.get("get_test_summary")
            if summary_session:
                summary_result = await summary_session.call_tool("get_test_summary", {})
                summary_structured = await get_structured_response(summary_result)
                summary_data = summary_structured.get("result", summary_structured)
                print("Final Test Summary:")
                print(json.dumps(summary_data, indent=2, ensure_ascii=False))
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nAgent manually interrupted.")
    except Exception as e:
        print(f"\nAn unexpected error occurred in main: {e}")
        traceback.print_exc()