# -*- coding: utf-8 -*-
# ddms_compliance_suite/llm_utils/llm_service.py

import os
import json
import logging
import re
from typing import Optional, Dict, Any, List

import requests
from pydantic import BaseModel, Field, ValidationError

logger = logging.getLogger(__name__)


# --- Sample Pydantic models (for testing and schema generation) ---
class SampleUserAddress(BaseModel):
    street: str = Field(..., description="Street address")
    city: str = Field(..., description="City")
    zip_code: str = Field(..., description="Postal code", pattern=r"^[0-9]{5,6}$")


class SampleUserProfile(BaseModel):
    user_id: int = Field(..., description="User ID", ge=1)
    username: str = Field(..., description="Username", min_length=3, max_length=50)
    email: Optional[str] = Field(None, description="User email address", pattern=r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$")
    is_active: bool = Field(True, description="Whether the account is active")
    address: Optional[SampleUserAddress] = Field(None, description="User address")
    hobbies: Optional[list[str]] = Field(None, description="List of hobbies")
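
# For reference (illustrative values only), a dict satisfying SampleUserProfile's
# constraints might look like:
#   {"user_id": 1, "username": "test_user_01", "email": "test@example.com",
#    "is_active": False, "hobbies": ["reading"],
#    "address": {"street": "88 Example Rd", "city": "Shanghai", "zip_code": "200000"}}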


# --- LLM service class ---
class LLMService:
    """
    Wraps interactions with a large language model (LLM) API for intelligent
    parameter generation and validation.

    Currently implemented against Tongyi Qianwen's OpenAI-compatible mode.
    """
    def __init__(self, api_key: str, base_url: str, model_name: str = "qwen-plus"):
        """
        Initialize the LLM service.

        Args:
            api_key: API key for the LLM service.
            base_url: OpenAI-compatible base URL of the LLM service.
            model_name: Name of the specific model to use.
        """
        if not api_key:
            raise ValueError("API key must not be empty")
        if not base_url:
            raise ValueError("Base URL must not be empty")

        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.model_name = model_name
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        logger.info(f"LLMService initialized for model '{self.model_name}' at {self.base_url}")

    def _execute_chat_completion_request(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 1024,
        temperature: float = 0.7,
        # TODO: Consider adding a parameter like response_format_type: Optional[str] = None
        # if the LLM API supports forcing JSON output (e.g., { "type": "json_object" })
    ) -> Optional[str]:
        """
        Execute a generic request against the LLM chat-completions endpoint.

        Args:
            messages: List of messages to send to the LLM
                (e.g., [{'role': 'system', 'content': '...'}, {'role': 'user', ...}]).
            max_tokens: Maximum number of tokens the LLM may generate.
            temperature: Sampling randomness of the LLM output.

        Returns:
            The raw text content returned by the LLM assistant, or None if an
            error occurred or no valid content was returned.
        """
        payload = {
            "model": self.model_name,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        # if response_format_type:
        #     payload["response_format"] = {"type": response_format_type}

        logger.debug(f"LLM API Request Payload:\n{json.dumps(payload, indent=2, ensure_ascii=False)}")

        try:
            response = requests.post(f"{self.base_url}/chat/completions", headers=self.headers, json=payload, timeout=60)
            response.raise_for_status()  # Raise if the HTTP status code indicates an error

            response_data = response.json()
            logger.debug(f"LLM API Response:\n{json.dumps(response_data, indent=2, ensure_ascii=False)}")

            if response_data.get("choices") and len(response_data["choices"]) > 0 and response_data["choices"][0].get("message"):
                assistant_message = response_data["choices"][0]["message"]
                assistant_response_content = assistant_message.get("content")
                if assistant_response_content:
                    return assistant_response_content
                else:
                    logger.warning("Message content in the LLM response is empty.")
            else:
                logger.warning(f"LLM response has an unexpected format or empty choices: {response_data}")

        except requests.exceptions.RequestException as e_req:
            logger.error(f"Network error while calling the LLM API: {e_req}", exc_info=True)
        except Exception as e:
            logger.error(f"Unexpected error during the LLM chat completion request: {e}", exc_info=True)

        return None
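
    # For reference: a successful response from an OpenAI-compatible
    # chat-completions endpoint has roughly this (abridged) shape, which is
    # what the parsing above relies on:
    #   {"choices": [{"message": {"role": "assistant", "content": "..."}}]}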

    def generate_parameters_from_schema(
        self,
        pydantic_model_class: type[BaseModel],
        prompt_instructions: Optional[str] = None,
        max_tokens: int = 1024,
        temperature: float = 0.7
    ) -> Optional[Dict[str, Any]]:
        """
        Generate a JSON Schema from the given Pydantic model class, then call the
        LLM to produce a parameter dict conforming to that schema.
        """
        try:
            # 1. Generate a JSON Schema from the Pydantic model
            model_schema = pydantic_model_class.model_json_schema(ref_template='{model}')
            main_model_name = pydantic_model_class.__name__
            schema_str = json.dumps(model_schema, indent=2, ensure_ascii=False)
            logger.debug(f"Generated JSON Schema for '{main_model_name}':\n{schema_str}")

            # 2. Build the prompt
            system_prompt = (
                "You are an API test-data generation assistant. Given a JSON Schema and "
                "additional instructions from the user, your task is to generate a JSON "
                "object that conforms to the schema. Your output must be strictly a single "
                "JSON object, with no extra explanations, comments, or Markdown markup."
            )
            user_prompt_content = f"Please generate a valid JSON object instance for the following JSON Schema:\n\n```json\n{schema_str}\n```\n"
            if prompt_instructions:
                user_prompt_content += f"\nPlease also follow these additional instructions:\n{prompt_instructions}"
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt_content}
            ]

            # 3. Call the generic LLM request helper to get the raw response
            assistant_response_content = self._execute_chat_completion_request(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )

            # 4. Parse the response
            if assistant_response_content:
                # Try to extract the JSON portion from the returned content
                json_match = re.search(r'```json\n(.*?)\n```', assistant_response_content, re.DOTALL)
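                # e.g. a fenced reply such as "```json\n{...}\n```" yields the
                # inner JSON text via group(1); otherwise the fallback below
                # takes the outermost brace pair.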
                if json_match:
                    json_str = json_match.group(1)
                else:
                    first_brace = assistant_response_content.find('{')
                    last_brace = assistant_response_content.rfind('}')
                    if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
                        json_str = assistant_response_content[first_brace : last_brace + 1]
                    else:
                        json_str = assistant_response_content

                try:
                    generated_params = json.loads(json_str)
                    pydantic_model_class.model_validate(generated_params)
                    logger.info(f"Successfully generated and validated parameters for '{main_model_name}' from the LLM.")
                    return generated_params
                except json.JSONDecodeError as e_json:
                    logger.error(f"Could not parse the LLM response as JSON: {e_json}\nRaw response excerpt: '{json_str[:500]}'")
                except ValidationError as e_val:
                    logger.error(f"LLM-generated parameters failed Pydantic model validation: {e_val}\nGenerated parameters: {json_str}")
            else:
                logger.warning("The response content from the LLM is empty or the request failed.")

        except Exception as e:
            logger.error(f"Unexpected error during LLM parameter generation: {e}", exc_info=True)

        return None
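

# A minimal sketch (not part of the original API) of how a caller might retry
# generation when the LLM returns unparsable or invalid output. The function
# name generate_with_retries and the max_attempts parameter are hypothetical.
def generate_with_retries(
    service: LLMService,
    model_class: type[BaseModel],
    max_attempts: int = 3,
    **kwargs: Any,
) -> Optional[Dict[str, Any]]:
    """Call generate_parameters_from_schema up to max_attempts times."""
    for attempt in range(1, max_attempts + 1):
        params = service.generate_parameters_from_schema(model_class, **kwargs)
        if params is not None:
            return params
        logger.warning(f"Generation attempt {attempt}/{max_attempts} failed; retrying.")
    return None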


# --- Example usage (for in-module testing) ---
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    DASH_API_KEY = os.environ.get("DASHSCOPE_API_KEY") or "sk-YOUR_DASHSCOPE_API_KEY"
    DASH_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"

    if "YOUR_DASHSCOPE_API_KEY" in DASH_API_KEY:
        logger.warning("Please replace DASH_API_KEY with a valid API key, or set the DASHSCOPE_API_KEY environment variable.")

    llm_service_instance = LLMService(api_key=DASH_API_KEY, base_url=DASH_BASE_URL)

    logger.info("\n--- Testing SampleUserProfile parameter generation ---")
    generated_profile = llm_service_instance.generate_parameters_from_schema(
        pydantic_model_class=SampleUserProfile,
        prompt_instructions="Please generate a profile for an inactive user whose username contains \"test_user\", whose city is Shanghai, and who has at least one hobby."
    )

    if generated_profile:
        logger.info(f"Successfully generated UserProfile:\n{json.dumps(generated_profile, indent=2, ensure_ascii=False)}")
        try:
            SampleUserProfile.model_validate(generated_profile)
            logger.info("The generated UserProfile passed Pydantic validation.")
        except ValidationError as e:
            logger.error(f"The generated UserProfile failed Pydantic validation: {e}")
    else:
        logger.warning("Failed to generate a UserProfile.")

    logger.info("\n--- Testing SampleUserAddress parameter generation ---")
    generated_address = llm_service_instance.generate_parameters_from_schema(
        pydantic_model_class=SampleUserAddress,
        prompt_instructions="Generate an address in Chaoyang District, Beijing, with a postal code starting with 1000."
    )
    if generated_address:
        logger.info(f"Successfully generated UserAddress:\n{json.dumps(generated_address, indent=2, ensure_ascii=False)}")
    else:
        logger.warning("Failed to generate a UserAddress.")

    logger.info("\n--- Testing SampleUserProfile without extra instructions ---")
    generated_profile_no_instr = llm_service_instance.generate_parameters_from_schema(
        pydantic_model_class=SampleUserProfile
    )
    if generated_profile_no_instr:
        logger.info(f"Successfully generated UserProfile (no instructions):\n{json.dumps(generated_profile_no_instr, indent=2, ensure_ascii=False)}")
    else:
        logger.warning("Failed to generate a UserProfile (no instructions).")