# Snapshot metadata (from the web viewer this file was copied from):
#   2025-06-01 00:38:36 +08:00 — 285 lines, 13 KiB, Python (Raw / Blame / History).
# The viewer also warned: "This file contains ambiguous Unicode characters" —
# these are the full-width punctuation marks in the Chinese comments and
# runtime strings below, which are intentional.
# -*- coding: utf-8 -*-
# ddms_compliance_suite/llm_utils/llm_service.py
import os
import json
import logging
import re
from typing import Optional, Dict, Any, List
import requests
from pydantic import BaseModel, Field
from pydantic.json_schema import models_json_schema
logger = logging.getLogger(__name__)
# --- Pydantic 模型示例 (用于测试和生成Schema) ---
# Sample Pydantic model of a postal address. Used by the __main__ self-test to
# exercise LLM-driven parameter generation from a JSON Schema.
# NOTE: a class docstring is deliberately omitted — Pydantic would copy it into
# the generated schema's "description". Field description strings are runtime
# values sent to the LLM, so they are kept in the original Chinese.
class SampleUserAddress(BaseModel):
    street: str = Field(..., description="街道地址")  # street address
    city: str = Field(..., description="城市")  # city
    zip_code: str = Field(..., description="邮政编码", pattern=r"^[0-9]{5,6}$")  # postal code, 5-6 digits
# Sample Pydantic model of a user profile (nests SampleUserAddress). Used by
# the __main__ self-test to exercise LLM-driven parameter generation.
# NOTE: no class docstring — Pydantic would emit it as the schema "description".
# Field description strings are runtime values sent to the LLM (kept in Chinese).
class SampleUserProfile(BaseModel):
    user_id: int = Field(..., description="用户ID", ge=1)  # positive user id
    username: str = Field(..., description="用户名", min_length=3, max_length=50)  # 3-50 chars
    email: Optional[str] = Field(None, description="用户邮箱", pattern=r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$")  # optional, simple e-mail pattern
    is_active: bool = Field(True, description="账户是否激活")  # defaults to active
    address: Optional[SampleUserAddress] = Field(None, description="用户地址")  # optional nested address
    hobbies: Optional[list[str]] = Field(None, description="兴趣爱好列表")  # optional list of hobbies
# --- LLM 服务类 ---
class LLMService:
    """Wrapper around a Large Language Model (LLM) chat-completion API.

    Used for intelligent parameter generation and validation. Currently
    implemented against Tongyi Qianwen's (DashScope) OpenAI-compatible mode,
    but any OpenAI-compatible ``/chat/completions`` endpoint should work.
    """

    def __init__(self, api_key: str, base_url: str, model_name: str = "qwen-plus",
                 timeout: float = 60.0):
        """Initialize the LLM service.

        Args:
            api_key: API key for the LLM service. Must be non-empty.
            base_url: OpenAI-compatible base URL. A trailing slash is stripped
                so endpoint paths can be appended safely.
            model_name: Concrete model name to request.
            timeout: Per-request timeout in seconds. New, backward-compatible
                parameter; the default preserves the previously hard-coded 60s.

        Raises:
            ValueError: If ``api_key`` or ``base_url`` is empty.
        """
        if not api_key:
            raise ValueError("API Key不能为空")
        if not base_url:
            raise ValueError("Base URL不能为空")
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.model_name = model_name
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        logger.info(f"LLMService initialized for model '{self.model_name}' at {self.base_url}")

    @staticmethod
    def _extract_json_str(content: str) -> str:
        """Best-effort extraction of a JSON payload from raw LLM output.

        Tries, in order:
          1. A fenced code block — ```json ... ``` or a plain ``` ... ```
             fence. (The previous regex only accepted the exact
             ```json\\n...\\n``` form and silently missed untagged fences.)
          2. The substring from the first '{' to the last '}'.
          3. Falls back to returning ``content`` unchanged.
        """
        fence = re.search(r'```(?:json)?\s*(.*?)\s*```', content, re.DOTALL)
        if fence:
            return fence.group(1)
        first_brace = content.find('{')
        last_brace = content.rfind('}')
        if first_brace != -1 and last_brace > first_brace:
            return content[first_brace:last_brace + 1]
        return content

    @staticmethod
    def _build_schema_messages(schema_str: str,
                               prompt_instruction: Optional[str]) -> List[Dict[str, str]]:
        """Build the chat messages asking the LLM for a JSON instance of a schema.

        Shared by generate_parameters_from_schema and generate_data_from_schema
        (the prompt code was previously duplicated in both). The prompt texts
        are runtime strings sent to the model and are kept verbatim (Chinese).
        """
        system_prompt = (
            "你是一个API测试数据生成助手。你的任务是根据用户提供的JSON Schema和额外指令"
            "生成一个符合该Schema的JSON对象。请确保你的输出严格是一个JSON对象"
            "不包含任何额外的解释、注释或Markdown标记。"
        )
        user_prompt_content = f"请为以下JSON Schema生成一个有效的JSON对象实例:\n\n```json\n{schema_str}\n```\n"
        if prompt_instruction:
            user_prompt_content += f"\n请遵循以下额外指令:\n{prompt_instruction}"
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt_content},
        ]

    def _execute_chat_completion_request(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 1024,
        temperature: float = 0.1,
        # TODO: Consider adding a parameter like response_format_type: Optional[str] = None
        # if the LLM API supports forcing JSON output (e.g., { "type": "json_object" })
    ) -> Optional[str]:
        """POST a generic request to the LLM chat-completion endpoint.

        Args:
            messages: Message list for the LLM, e.g.
                ``[{'role': 'system', 'content': '...'}, {'role': 'user', ...}]``.
            max_tokens: Maximum number of tokens the LLM may generate.
            temperature: Sampling randomness.

        Returns:
            The assistant's raw text content, or None on any error or when the
            response carries no usable content. Errors are logged, never
            raised, so callers treat None as "generation failed".
        """
        payload = {
            "model": self.model_name,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        # if response_format_type:
        #     payload["response_format"] = {"type": response_format_type}
        logger.debug(f"LLM API Request Payload:\n{json.dumps(payload, indent=2, ensure_ascii=False)}")
        try:
            response = requests.post(
                f"{self.base_url}/chat/completions",
                headers=self.headers,
                json=payload,
                timeout=self.timeout,  # previously hard-coded to 60 seconds
            )
            response.raise_for_status()  # raise on HTTP error status codes
            response_data = response.json()
            logger.debug(f"LLM API Response:\n{json.dumps(response_data, indent=2, ensure_ascii=False)}")
            choices = response_data.get("choices")
            if choices and choices[0].get("message"):
                content = choices[0]["message"].get("content")
                if content:
                    return content
                logger.warning("LLM响应中消息内容为空。")
            else:
                logger.warning(f"LLM响应格式不符合预期或choices为空: {response_data}")
        except requests.exceptions.RequestException as e_req:
            logger.error(f"请求LLM API时发生网络错误: {e_req}", exc_info=True)
        except Exception as e:
            logger.error(f"执行LLM聊天补全请求时发生未知错误: {e}", exc_info=True)
        return None

    def generate_parameters_from_schema(
        self,
        pydantic_model_class: "type[BaseModel]",
        prompt_instruction: Optional[str] = None,
        max_tokens: int = 1024,
        temperature: float = 0.1
    ) -> Optional[Dict[str, Any]]:
        """Generate a parameter dict conforming to a Pydantic model.

        Derives the JSON Schema from ``pydantic_model_class``, asks the LLM
        for a matching instance, parses the reply, and validates it against
        the model.

        Returns:
            The validated parameter dict, or None on any failure (request
            error, unparsable JSON, or Pydantic validation failure — all
            logged rather than raised).
        """
        try:
            # 1. Derive the JSON Schema from the Pydantic model.
            model_schema = pydantic_model_class.model_json_schema(ref_template='{model}')
            main_model_name = pydantic_model_class.__name__
            schema_str = json.dumps(model_schema, indent=2, ensure_ascii=False)
            logger.debug(f"Generated JSON Schema for '{main_model_name}':\n{schema_str}")
            # 2. Build the prompt (shared helper).
            messages = self._build_schema_messages(schema_str, prompt_instruction)
            # 3. Call the generic LLM request helper.
            assistant_response_content = self._execute_chat_completion_request(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )
            # 4. Parse and validate the response.
            if assistant_response_content:
                json_str = self._extract_json_str(assistant_response_content)
                try:
                    generated_params = json.loads(json_str)
                    pydantic_model_class.model_validate(generated_params)
                    logger.info(f"成功从LLM生成并验证了 '{main_model_name}' 的参数。")
                    return generated_params
                except json.JSONDecodeError as e_json:
                    logger.error(f"无法将LLM响应解析为JSON: {e_json}\n原始响应片段: '{json_str[:500]}'")
                except Exception as e_val:  # pydantic ValidationError; kept broad to avoid a new import
                    logger.error(f"LLM生成的参数未能通过Pydantic模型验证: {e_val}\n生成的参数: {json_str}")
            else:
                logger.warning("从LLM获取的响应内容为空或请求失败。")
        except Exception as e:
            logger.error(f"执行LLM参数生成时发生未知错误: {e}", exc_info=True)
        return None

    def generate_data_from_schema(
        self,
        schema_dict: dict,
        prompt_instruction: Optional[str] = None,
        max_tokens: int = 1024,
        temperature: float = 0.1
    ) -> Optional[Dict[str, Any]]:
        """Generate a data object conforming to a raw JSON Schema dict.

        Like generate_parameters_from_schema, but takes the schema directly
        and performs no Pydantic validation on the result.

        Returns:
            The parsed data dict, or None on any failure (logged, not raised).
        """
        try:
            schema_str = json.dumps(schema_dict, indent=2, ensure_ascii=False)
            logger.debug(f"LLMService.generate_data_from_schema: 使用的JSON Schema:\n{schema_str}")
            messages = self._build_schema_messages(schema_str, prompt_instruction)
            assistant_response_content = self._execute_chat_completion_request(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )
            if assistant_response_content:
                json_str = self._extract_json_str(assistant_response_content)
                try:
                    generated_data = json.loads(json_str)
                    logger.info("成功从LLM生成并解析了数据。")
                    return generated_data
                except json.JSONDecodeError as e_json:
                    logger.error(f"无法将LLM响应解析为JSON: {e_json}\n原始响应片段: '{json_str[:500]}'")
            else:
                logger.warning("从LLM获取的响应内容为空或请求失败。")
        except Exception as e:
            logger.error(f"执行LLM数据生成时发生未知错误: {e}", exc_info=True)
        return None
# --- 示例用法 (用于模块内测试) ---
if __name__ == '__main__':
    # Module self-test: exercises parameter generation against a live LLM
    # endpoint (requires a valid DashScope API key to actually succeed).
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Credentials: prefer the environment variable, fall back to a placeholder.
    api_key = os.environ.get("DASHSCOPE_API_KEY") or "sk-YOUR_DASHSCOPE_API_KEY"
    base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    if "YOUR_DASHSCOPE_API_KEY" in api_key:
        # Warn but proceed; the calls below will simply fail and log errors.
        logger.warning("请将 DASH_API_KEY 替换为您的有效API密钥或设置 DASHSCOPE_API_KEY 环境变量。")

    service = LLMService(api_key=api_key, base_url=base_url)

    # Case 1: profile generation with an extra natural-language instruction.
    logger.info("\n--- 测试 SampleUserProfile 参数生成 ---")
    profile = service.generate_parameters_from_schema(
        pydantic_model_class=SampleUserProfile,
        prompt_instruction="请生成一个表示非活跃用户的配置文件,用户名包含 \"test_user\" 字样,城市为上海,并包含至少一个兴趣爱好。"
    )
    if profile is None:
        logger.warning("未能生成UserProfile。")
    else:
        logger.info(f"成功生成的 UserProfile:\n{json.dumps(profile, indent=2, ensure_ascii=False)}")
        # Double-check the generated dict against the model.
        try:
            SampleUserProfile.model_validate(profile)
        except Exception as exc:
            logger.error(f"生成的UserProfile未能通过Pydantic验证: {exc}")
        else:
            logger.info("生成的UserProfile通过了Pydantic验证。")

    # Case 2: nested-model generation with an instruction.
    logger.info("\n--- 测试 SampleUserAddress 参数生成 ---")
    address = service.generate_parameters_from_schema(
        pydantic_model_class=SampleUserAddress,
        prompt_instruction="生成一个位于北京市朝阳区的地址邮编以1000开头。"
    )
    if address is None:
        logger.warning("未能生成UserAddress。")
    else:
        logger.info(f"成功生成的 UserAddress:\n{json.dumps(address, indent=2, ensure_ascii=False)}")

    # Case 3: generation with no extra instruction at all.
    logger.info("\n--- 测试 SampleUserProfile 无额外指令 ---")
    plain_profile = service.generate_parameters_from_schema(
        pydantic_model_class=SampleUserProfile
    )
    if plain_profile is None:
        logger.warning("未能生成 (无指令) UserProfile。")
    else:
        logger.info(f"成功生成的 (无指令) UserProfile:\n{json.dumps(plain_profile, indent=2, ensure_ascii=False)}")