"""LLM service wrapper for the autonomous API-testing agent."""
import json
import os
import traceback
from typing import Any, Dict, List, Optional, Tuple

import openai
class LLMService:
    """Conversation manager for an LLM-driven API-testing agent.

    Wraps an OpenAI-compatible client pointed at the DashScope
    (Alibaba Cloud) endpoint and maintains the running message history
    used for autonomous tool-calling loops.
    """

    def __init__(self, model_name="qwen-plus", tools=None):
        """Create the client and seed the conversation with the system prompt.

        Args:
            model_name: Chat model identifier (default "qwen-plus").
            tools: OpenAI-format tool/function schemas; defaults to no tools.

        Raises:
            ValueError: If the OPENAI_API_KEY environment variable is unset.
        """
        if tools is None:
            tools = []
        self.api_key = os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("API key not found. Please set the OPENAI_API_KEY environment variable.")

        self.client = openai.OpenAI(api_key=self.api_key, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
        self.model_name = model_name
        self.tools = tools
        # The system prompt is runtime model input — kept verbatim (Chinese).
        self.system_prompt = {"role": "system", "content": "你是一个智能API测试Agent。你的任务是根据用户的要求,通过自主、连续地调用给定的工具来完成API的自动化测试。请仔细分析每一步的结果,并决定下一步应该调用哪个工具。"}
        self.messages: List[Dict[str, Any]] = [self.system_prompt]

    def start_new_task(self, task_description: str):
        """Start a new task: reset the conversation history, but keep the
        most recent tool response (if any) so the model has context for
        the task switch.
        """
        print(f"\n{'='*25} Starting New Task Context {'='*25}")
        print(f"Task Description: {task_description}")

        last_tool_call_response = None
        if len(self.messages) > 1 and self.messages[-1]["role"] == "tool":
            last_tool_call_response = self.messages[-1]

        self.messages = [self.system_prompt]
        if last_tool_call_response:
            # NOTE(review): a "tool" message without its preceding assistant
            # tool_calls message is an orphan; strict OpenAI-compatible
            # backends may reject it — confirm DashScope tolerates this.
            self.messages.append(last_tool_call_response)
            print(f"Preserving last tool response for context: {last_tool_call_response['name']}")

        self.add_user_message(task_description)
        print(f"{'='*72}\n")

    def add_user_message(self, content: str):
        """Append a user-role message to the conversation history."""
        self.messages.append({"role": "user", "content": content})

    def add_tool_call_response(self, tool_call_id: str, content: str):
        """Append a tool-result message answering the given tool call.

        Args:
            tool_call_id: Id of the assistant tool call being answered;
                this is what the API matches on.
            content: The tool's result, serialized as a string.
        """
        self.messages.append(
            {
                "tool_call_id": tool_call_id,
                "role": "tool",
                "name": tool_call_id,  # the name may mirror the id; only the id must match
                "content": content,
            }
        )

    def get_last_assistant_message(self) -> str:
        """Return the most recent assistant message that has text content,
        or a placeholder string when none exists."""
        for msg in reversed(self.messages):
            if msg["role"] == "assistant" and msg.get("content"):
                return msg["content"]
        return "No final response from assistant."

    def execute_completion(self) -> Tuple[Optional[str], Optional[Dict[str, Any]], Optional[str]]:
        """Send the current history to the model and interpret the reply.

        Returns:
            A ``(tool_name, tool_args, tool_call_id)`` triple:

            * ``(name, args, id)`` when the model requested a tool call
              (the assistant turn is recorded in the history);
            * ``("error_malformed_json", details, None)`` when the tool
              arguments were not valid JSON (the broken assistant turn is
              NOT recorded);
            * ``(None, None, None)`` for a plain text reply or an API error.
        """
        print("\n" + "="*25 + " LLM Request " + "="*25)
        print(json.dumps({"model": self.model_name, "messages": self.messages, "tools": self.tools}, ensure_ascii=False, indent=2))
        print("="*71)

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=self.messages,
                tools=self.tools,
                tool_choice="auto",
            )

            print("\n" + "="*25 + " LLM Response " + "="*25)
            print(response)
            print("="*72)

            response_message = response.choices[0].message
            tool_calls = response_message.tool_calls

            if tool_calls:
                # Only the first tool call of the turn is honored.
                tool_call = tool_calls[0]
                tool_name = tool_call.function.name
                tool_call_id = tool_call.id

                try:
                    tool_args = json.loads(tool_call.function.arguments)
                    # Record the successful tool call in the history.
                    self.messages.append(response_message.model_dump(exclude_unset=True))
                    return tool_name, tool_args, tool_call_id
                except json.JSONDecodeError as e:
                    error_msg = f"LLM generated malformed JSON for tool arguments: {tool_call.function.arguments}. Error: {e}"
                    print(error_msg)
                    # Do not record the broken assistant turn; return an error signal instead.
                    return "error_malformed_json", {"tool_name": tool_name, "malformed_arguments": tool_call.function.arguments, "error": str(e)}, None

            # No tool call: record the plain assistant reply in the history.
            self.messages.append(response_message.model_dump(exclude_unset=True))
            return None, None, None

        except Exception as e:
            # Boundary catch: log the failure and degrade to "no reply"
            # rather than crash the agent loop.
            print(f"Error calling LLM API: {e}")
            traceback.print_exc()
            return None, None, None