import os
import sys
import json
import logging
import argparse
import traceback # 用于更详细的错误日志
import uuid # For unique filenames
from pathlib import Path
import sqlite3 # <-- ADDED: For SQLite database
from werkzeug.security import generate_password_hash, check_password_hash # <-- ADDED: For password hashing
from werkzeug.utils import secure_filename # For safe filenames
from flask import Flask, request, jsonify, send_from_directory, session, redirect, url_for, render_template_string, g, flash # <-- MODIFIED: Added flash
from flask_cors import CORS # 用于处理跨域请求
from urllib.parse import urlparse
# 将ddms_compliance_suite的父目录添加到sys.path
# 假设flask_app.py与ddms_compliance_suite目录在同一级别,或者ddms_compliance_suite在其PYTHONPATH中
# 如果 ddms_compliance_suite 是一个已安装的包,则不需要这个
# current_dir = os.path.dirname(os.path.abspath(__file__))
# project_root = os.path.dirname(current_dir) # 假设项目根目录是上一级
# sys.path.insert(0, project_root)
# 或者更具体地添加包含ddms_compliance_suite的目录
# sys.path.insert(0, os.path.join(project_root, 'ddms_compliance_suite'))
from ddms_compliance_suite.api_caller.caller import APICallDetail
from ddms_compliance_suite.test_orchestrator import APITestOrchestrator, TestSummary
from ddms_compliance_suite.input_parser.parser import InputParser, ParsedYAPISpec, ParsedSwaggerSpec
# 从 run_api_tests.py 导入辅助函数 (如果它们被重构为可导入的)
# 为了简单起见,我们可能会直接在 flask_app.py 中重新实现一些逻辑或直接调用Orchestrator
app = Flask(__name__, static_folder='static', static_url_path='')
# Allow cross-origin requests from any origin; production deployments should
# restrict this to the known front-end origin(s).
CORS(app)

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Directory that contains this script. In a PyInstaller bundle, use the
# bundle's extraction directory (sys._MEIPASS) instead.
# NOTE(review): in one-file builds _MEIPASS is a temporary directory, so the
# database, uploads and reports stored below it may not persist between runs
# -- confirm the build mode.
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
if hasattr(sys, '_MEIPASS'):
    APP_ROOT = sys._MEIPASS

UPLOAD_FOLDER = os.path.join(APP_ROOT, 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

DATABASE = os.path.join(APP_ROOT, 'users.db')  # SQLite database with user accounts
# Sessions are signed with SECRET_KEY. A random per-process key logs every user
# out whenever the server restarts and breaks sessions when several worker
# processes serve the app. Prefer a stable key from the FLASK_SECRET_KEY
# environment variable; fall back to a random key only when it is not set.
app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY') or os.urandom(24)
app.config['DATABASE'] = DATABASE

# Database schema, embedded in the code so no external schema.sql is needed.
# Applying it is destructive: it drops any existing `user` table first.
DB_SCHEMA = '''
DROP TABLE IF EXISTS user;
CREATE TABLE user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL
);
'''
# --- 数据库辅助函数 ---
def get_db():
    """Return the SQLite connection bound to the current application context.

    The connection to ``DATABASE`` is opened lazily on first use and cached on
    ``flask.g`` under ``_database``. Its ``row_factory`` is ``sqlite3.Row``, so
    columns can be accessed by name.
    """
    cached = getattr(g, '_database', None)
    if cached is not None:
        return cached
    connection = sqlite3.connect(DATABASE)
    connection.row_factory = sqlite3.Row  # rows support row['column'] access
    g._database = connection
    return connection
@app.teardown_appcontext
def close_connection(exception):
    """Close the context's SQLite connection when the app context is torn down.

    Does nothing if ``get_db`` was never called in this context. *exception*
    is supplied by Flask and is not used.
    """
    connection = getattr(g, '_database', None)
    if connection is None:
        return
    connection.close()
def init_db(force_create=False):
    """Apply the embedded ``DB_SCHEMA`` script to the database.

    The schema is applied when *force_create* is true or the database file does
    not exist yet; otherwise only a log line is written. Because ``DB_SCHEMA``
    starts by dropping the ``user`` table, forcing re-creation deletes every
    existing account.
    """
    should_apply = force_create or not os.path.exists(DATABASE)
    if not should_apply:
        logger.info("Database already exists.")
        return
    with app.app_context():
        connection = get_db()
        connection.executescript(DB_SCHEMA)
        connection.commit()
        logger.info("Database initialized!")
def create_default_user(username="admin", password="admin123"):
    """Ensure an account named *username* exists, creating it if necessary.

    A row (storing a werkzeug hash of *password*) is inserted only when no user
    with that name is present; an existing account is left untouched. The
    default arguments are well-known bootstrap credentials.
    """
    with app.app_context():
        db = get_db()
        already_there = db.execute(
            'SELECT * FROM user WHERE username = ?', (username,)
        ).fetchone() is not None
        if already_there:
            logger.info(f"默认用户 {username} 已存在")
            return
        hashed = generate_password_hash(password)
        db.execute(
            "INSERT INTO user (username, password_hash) VALUES (?, ?)",
            (username, hashed),
        )
        db.commit()
        logger.info(f"已创建默认用户: {username}")
@app.cli.command('init-db')
def init_db_command():
    """Implement ``flask init-db``: recreate the ``user`` table and the default account.

    ``init_db(force_create=True)`` drops every existing account, because
    ``DB_SCHEMA`` begins with ``DROP TABLE IF EXISTS user``. Self-registration
    is disabled, so the default user is recreated as well. Without it, nobody
    could log in to the freshly initialized database.
    """
    init_db(force_create=True)
    create_default_user()
    print("Initialized the database.")
# --- User authentication routes ---
# The registration template is disabled: the bare string below is neither
# assigned nor rendered; it is kept only for reference.
"""
REGISTER_TEMPLATE = '''
注册
注册新用户
{% with messages = get_flashed_messages() %}
{% if messages %}
{% for message in messages %}
- {{ message }}
{% endfor %}
{% endif %}
{% endwith %}
已有账户? 点此登录
'''
"""
# Jinja template for the /login page. The form posts `username` and `password`,
# which are exactly the fields login() reads from request.form. Flash messages
# are shown together with their category ('success' / 'error' / default).
LOGIN_TEMPLATE = '''
<!doctype html>
<html lang="zh">
<head>
  <meta charset="utf-8">
  <title>登录</title>
</head>
<body>
  <h1>请登录</h1>
  {% with messages = get_flashed_messages(with_categories=true) %}
    {% if messages %}
      <ul class="flashes">
      {% for category, message in messages %}
        <li class="{{ category }}">{{ message }}</li>
      {% endfor %}
      </ul>
    {% endif %}
  {% endwith %}
  <form method="post">
    <label for="username">用户名</label>
    <input type="text" name="username" id="username" required autofocus>
    <label for="password">密码</label>
    <input type="password" name="password" id="password" required>
    <button type="submit">登录</button>
  </form>
</body>
</html>
'''
# 注册功能已暂时禁用
# @app.route('/register', methods=('GET', 'POST'))
# def register():
# if request.method == 'POST':
# username = request.form['username']
# password = request.form['password']
# db = get_db()
# error = None
#
# if not username:
# error = '用户名是必需的.'
# elif not password:
# error = '密码是必需的.'
#
# if error is None:
# try:
# db.execute(
# "INSERT INTO user (username, password_hash) VALUES (?, ?)",
# (username, generate_password_hash(password)),
# )
# db.commit()
# except db.IntegrityError: # Username already exists
# error = f"用户 {username} 已被注册."
# else:
# flash('注册成功! 请登录.')
# return redirect(url_for("login"))
#
# flash(error)
# return render_template_string(REGISTER_TEMPLATE)
@app.route('/login', methods=('GET', 'POST'))
def login():
    """Show the login form (GET) or authenticate submitted credentials (POST).

    On success the session is reset to hold ``user_id`` and ``username``, and
    the user is redirected to the index page. On failure an ``error`` flash
    message is queued and the form is rendered again. Visitors who are already
    logged in are redirected to the index page.
    """
    if request.method == 'POST':
        # .get() instead of [] so a missing field is a failed login, not a 400.
        username = request.form.get('username', '')
        password = request.form.get('password', '')
        user = None
        if username and password:
            user = get_db().execute(
                'SELECT * FROM user WHERE username = ?',
                (username,)
            ).fetchone()
        # Use one generic message for "unknown user" and "wrong password" so
        # the response does not reveal which usernames exist.
        if user is not None and check_password_hash(user['password_hash'], password):
            session.clear()
            session['user_id'] = user['id']
            session['username'] = user['username']
            flash('登录成功!', 'success')
            return redirect(url_for('serve_index'))
        flash('用户名或密码错误.', 'error')

    # If the user is already logged in, go straight to the index page.
    if 'user_id' in session:
        return redirect(url_for('serve_index'))
    return render_template_string(LOGIN_TEMPLATE)
@app.route('/logout')
def logout():
    """Forget the whole session, queue a confirmation message and show the login page."""
    session.clear()
    flash('您已成功登出.')
    login_page = url_for('login')
    return redirect(login_page)
# --- 应用辅助函数和路由保护 ---
from functools import wraps
from flask import g, flash # ensure g and flash are imported
def login_required(view):
    """Decorate a view so that anonymous visitors are redirected to the login page.

    The check reads ``g.user``, which ``load_logged_in_user`` sets before every
    request. The wrapped view keeps its name and docstring (via ``wraps``), so
    Flask endpoint names are unchanged.
    """
    @wraps(view)
    def wrapped_view(**kwargs):
        if g.user is not None:
            return view(**kwargs)
        return redirect(url_for('login'))

    return wrapped_view
@app.before_request
def load_logged_in_user():
    """Publish the current user's database row as ``g.user`` before each request.

    ``g.user`` is ``None`` when the session carries no ``user_id``. It is also
    ``None`` when that id no longer matches a row, because ``fetchone()``
    returns ``None``.
    """
    current_id = session.get('user_id')
    g.user = (
        None
        if current_id is None
        else get_db().execute('SELECT * FROM user WHERE id = ?', (current_id,)).fetchone()
    )
# --- Helper functions ---
def _render_body(body):
    """Return ``(fence_language, text)`` describing how to print a message body.

    Bytes are decoded as UTF-8, falling back to ``str(body)`` when decoding
    fails. Strings that parse as JSON are pretty-printed; other strings are
    returned verbatim as text. Any other object (e.g. a dict or list) is
    serialized as JSON, and values ``json`` cannot encode are stringified.
    """
    if isinstance(body, bytes):
        try:
            body = body.decode('utf-8')
        except UnicodeDecodeError:
            return "text", str(body)
    if isinstance(body, str):
        try:
            parsed = json.loads(body)
        except json.JSONDecodeError:
            return "text", body
        return "json", json.dumps(parsed, indent=2, ensure_ascii=False)
    return "json", json.dumps(body, indent=2, ensure_ascii=False, default=str)


def _write_json_block(f, data):
    """Write *data* to *f* as a pretty-printed fenced ``json`` block.

    ``default=str`` is deliberate. This is a human-readable report, and a
    header or parameter value that json cannot encode (e.g. a mapping type that
    is not a dict) should not abort the whole file.
    """
    f.write("```json\n")
    f.write(json.dumps(data, indent=2, ensure_ascii=False, default=str) + "\n")
    f.write("```\n")


def _write_body_block(f, label, body, request_url):
    """Write a labelled fenced block for a request/response body, best effort.

    Rendering errors are logged and replaced by a plain-text fallback, so one
    odd body does not prevent the rest of the report from being written.
    """
    f.write(f"- **{label}:**\n")
    try:
        fence, text = _render_body(body)
    except Exception as e_body:
        logger.error(f"Error processing {label.lower()} for API call to {request_url}: {e_body}")
        fence, text = "text", f"(Error processing {label.lower()}: {e_body})\n{body}"
    f.write(f"```{fence}\n{text}\n```\n")


def _write_api_call_section(f, detail_obj, anchor_id):
    """Write one collapsible ``<details id=anchor_id>`` section describing a single API call."""
    method = detail_obj.request_method.upper()
    elapsed_s = detail_obj.response_elapsed_time
    # The detail stores seconds; the report shows milliseconds ("N/A" if unknown).
    elapsed_ms = f"{elapsed_s * 1000:.2f}" if elapsed_s is not None else "N/A"

    f.write(f'<details id="{anchor_id}">\n')
    f.write(
        f"<summary>{method} {detail_obj.request_url} - 状态: "
        f"{detail_obj.response_status_code} - 耗时: {elapsed_ms}ms</summary>\n\n"
    )
    # The URL path is used as the endpoint name in the section heading.
    f.write(f"### 接口: {urlparse(detail_obj.request_url).path}\n\n")

    # Request details
    f.write("#### 请求 (Request)\n")
    f.write(f"- **Method:** `{method}`\n")
    f.write(f"- **URL:** `{detail_obj.request_url}`\n")
    if detail_obj.request_headers:
        f.write("- **Headers:**\n")
        _write_json_block(f, detail_obj.request_headers)
    if detail_obj.request_params:
        f.write("- **Query Parameters:**\n")
        _write_json_block(f, detail_obj.request_params)
    if detail_obj.request_body:
        _write_body_block(f, "Request Body", detail_obj.request_body, detail_obj.request_url)
    if detail_obj.curl_command:
        f.write("- **cURL Command:**\n")
        f.write("```bash\n")
        f.write(detail_obj.curl_command + "\n")
        f.write("```\n")
    f.write("\n")

    # Response details
    f.write("#### 响应 (Response)\n")
    f.write(f"- **Status Code:** `{detail_obj.response_status_code}`\n")
    f.write(f"- **Elapsed Time:** {elapsed_ms} ms\n")
    if detail_obj.response_headers:
        f.write("- **Response Headers:**\n")
        _write_json_block(f, detail_obj.response_headers)
    if detail_obj.response_body:
        _write_body_block(f, "Response Body", detail_obj.response_body, detail_obj.request_url)
    else:
        f.write("- Response Body: (empty)\n")

    f.write("\n</details>\n\n")
    f.write("---\n\n")  # Separator


def save_api_call_details_to_file(api_call_details: list[APICallDetail], output_dir_path_str: str, filename: str = "api_call_details.md"):
    """Save API call details to a Markdown report.

    Each call becomes a collapsible ``<details>`` section with the unique
    anchor ``api-call-N`` (N starts at 1). The section contains the request,
    the cURL command and the response.

    Args:
        api_call_details: Recorded calls. Each item must expose the
            ``request_*``, ``response_*`` and ``curl_command`` attributes read
            by ``_write_api_call_section``.
        output_dir_path_str: Target directory; created if missing.
        filename: Name of the Markdown file inside that directory.

    Returns:
        The written file path as a string. Returns ``None`` when there is
        nothing to write or writing fails; errors are logged, not raised.
    """
    if not api_call_details:
        logger.info("没有API调用详情可保存。")
        return None

    output_dir = Path(output_dir_path_str)
    file_path = output_dir / filename
    try:
        # mkdir lives inside the try so a failure to create the directory is
        # reported like any other I/O error instead of escaping to the caller.
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write("# API 调用详情记录\n\n")
            for index, detail_obj in enumerate(api_call_details, start=1):
                _write_api_call_section(f, detail_obj, f"api-call-{index}")
    except OSError as e:
        logger.error(f"保存 API 调用详情到文件时发生IO错误 '{file_path}': {e}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"保存 API 调用详情时发生未知错误 '{file_path}': {e}", exc_info=True)
        return None
    logger.info(f"API 调用详情已成功保存到: {file_path}")
    return str(file_path)
def _config_flag(config: dict, key: str) -> bool:
    """Read *key* from *config* as a boolean; a missing key means False.

    Form submissions deliver every value as a string, and any non-empty string
    is truthy, so ``"false"``/``"0"``/``"off"`` must be interpreted explicitly.
    Real booleans and other non-string values keep their normal truthiness.
    """
    value = config.get(key, False)
    if isinstance(value, str):
        return value.strip().lower() in ('true', '1', 'yes', 'on')
    return bool(value)


def get_orchestrator_from_config(config: dict) -> APITestOrchestrator:
    """Build an ``APITestOrchestrator`` from a flat configuration dict.

    *config* is typically ``request.form.to_dict()``, so every value may be a
    string. The ``use_llm_for_*`` switches are therefore parsed with
    ``_config_flag`` rather than passed through raw.
    """
    return APITestOrchestrator(
        base_url=config.get('base_url', ''),
        custom_test_cases_dir=config.get('custom_test_cases_dir'),
        stages_dir=config.get('stages_dir'),
        llm_api_key=config.get('llm_api_key'),
        llm_base_url=config.get('llm_base_url'),
        llm_model_name=config.get('llm_model_name'),
        use_llm_for_request_body=_config_flag(config, 'use_llm_for_request_body'),
        use_llm_for_path_params=_config_flag(config, 'use_llm_for_path_params'),
        use_llm_for_query_params=_config_flag(config, 'use_llm_for_query_params'),
        use_llm_for_headers=_config_flag(config, 'use_llm_for_headers'),
        # The orchestrator may not use this directly for saving reports, but
        # it is passed through anyway.
        output_dir=config.get('output_dir')
    )
# --- API endpoints ---
def _ensure_user_table():
    """Create the schema and the default user when the ``user`` table is missing.

    Covers both a missing database file and a file without the ``user`` table.
    Does nothing (apart from a debug log) when the table already exists.
    """
    if os.path.exists(DATABASE):
        try:
            get_db().execute("SELECT 1 FROM user LIMIT 1")
        except sqlite3.OperationalError:
            logger.info("数据库中没有user表,正在重新初始化...")
        else:
            # Runs on every page load, so keep this out of the INFO log.
            logger.debug("数据库user表存在")
            return
    else:
        logger.info("数据库不存在,正在初始化...")
    init_db(force_create=True)
    logger.info("数据库初始化完成")
    create_default_user()


@app.route('/')
@login_required  # Protect the main page
def serve_index():
    """Serve the single-page front end (``static/index.html``) to logged-in users.

    It first makes sure the user database is usable. NOTE(review): because of
    ``@login_required``, and because logging in already queries the ``user``
    table, this lazy check only runs for authenticated users, i.e. after the
    database exists. First-run initialisation effectively happens in the
    ``__main__`` block.
    """
    _ensure_user_table()
    return send_from_directory(app.static_folder, 'index.html')
# Substrings marking form fields whose values must never be written to the log.
_SENSITIVE_FIELD_MARKERS = ('key', 'token', 'secret', 'password')


def _redact_config(config):
    """Return a copy of *config* with credential-like values (e.g. llm_api_key) masked."""
    return {
        name: '***' if value and any(m in name.lower() for m in _SENSITIVE_FIELD_MARKERS) else value
        for name, value in config.items()
    }


def _split_csv_field(value):
    """Split a comma-separated form value into trimmed, non-empty items.

    Returns ``None`` ("no filter") when *value* is empty or yields no items.
    """
    if not value:
        return None
    items = [item.strip() for item in value.split(',') if item.strip()]
    return items or None


@app.route('/run-tests', methods=['POST'])
@login_required  # Protect this endpoint
def run_tests_endpoint():
    """Run API compliance tests against an uploaded YAPI or Swagger specification.

    Expects ``multipart/form-data``: the spec file in ``api_spec_file`` plus
    orchestrator settings as further form fields (``api_spec_type``,
    ``categories``, ``tags``, ``output_dir``, ...). Writes
    ``test_summary.json`` and ``api_call_details.md`` to the output directory
    and returns the summary plus download links as JSON.

    Responds 400 for a missing file, an unsupported spec type or an
    unparseable spec, and 500 for unexpected errors; the full traceback is
    only logged. The uploaded file is always deleted afterwards.
    """
    logger.info("Received request to run tests.")
    temp_spec_path = None  # Path of the uploaded spec file, removed in finally
    try:
        # The form is sent as multipart/form-data.
        config_data = request.form.to_dict()
        logger.info(f"Received config: {_redact_config(config_data)}")

        # Handle the file upload.
        if 'api_spec_file' not in request.files:
            logger.error("API spec file part is missing from the request.")
            return jsonify({"error": "API spec file part is missing"}), 400
        file = request.files['api_spec_file']
        if not file.filename:  # covers both '' and None
            logger.error("No API spec file selected.")
            return jsonify({"error": "No API spec file selected"}), 400
        # A UUID prefix avoids clashes between concurrent uploads.
        unique_filename = f"{uuid.uuid4()}_{secure_filename(file.filename)}"
        temp_spec_path = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
        file.save(temp_spec_path)
        logger.info(f"Saved uploaded spec file to {temp_spec_path}")

        orchestrator = get_orchestrator_from_config(config_data)

        # Determine the API spec type and parse the uploaded file.
        api_spec_type = config_data.get('api_spec_type', 'YAPI')
        logger.info(f"API Spec Type: {api_spec_type}")
        parser = InputParser()
        if api_spec_type == "YAPI":
            parsed_spec = parser.parse_yapi_spec(temp_spec_path)
        elif api_spec_type == "Swagger":
            parsed_spec = parser.parse_swagger_spec(temp_spec_path)
        else:
            error_msg = f"Unsupported API spec type: {api_spec_type}"
            logger.error(error_msg)
            return jsonify({"error": error_msg}), 400
        if not parsed_spec:
            error_msg = f"Failed to parse the uploaded {api_spec_type} file."
            logger.error(error_msg)
            return jsonify({"error": error_msg}), 400
        logger.info("Successfully parsed API specification.")

        # Execute the test cases from the parsed spec.
        logger.info(f"Starting test execution from parsed {api_spec_type} spec...")
        summary = orchestrator._execute_tests_from_parsed_spec(
            parsed_spec=parsed_spec,
            summary=TestSummary(),
            categories=_split_csv_field(config_data.get('categories')),
            tags=_split_csv_field(config_data.get('tags')),
            custom_test_cases_dir=config_data.get('custom_test_cases_dir')
        )
        logger.info("Test case execution finished.")

        # Execute the test stages.
        logger.info("Starting stage execution...")
        summary = orchestrator.run_stages_from_spec(
            parsed_spec=parsed_spec,
            summary=summary
        )
        logger.info("Stage execution finished.")

        # Write the reports; an empty form value falls back to the default.
        output_dir = config_data.get('output_dir') or './test_reports'
        if not os.path.isabs(output_dir):
            output_dir = os.path.join(APP_ROOT, output_dir)
        os.makedirs(output_dir, exist_ok=True)

        summary_path = os.path.join(output_dir, "test_summary.json")
        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write(summary.to_json(pretty=True))
        logger.info(f"Test summary saved to {summary_path}")

        details_path = save_api_call_details_to_file(orchestrator.get_api_call_details(), output_dir)
        if details_path:
            logger.info(f"API call details saved to {details_path}")

        # NOTE(review): /download/ serves files from APP_ROOT/test_reports only,
        # so these links resolve only when output_dir is left at its default.
        return jsonify({
            "summary": summary.to_dict(),
            "summary_report_path": "/download/test_summary.json",
            "details_report_path": "/download/api_call_details.md"
        })
    except Exception as e:
        # Log the full traceback, but return only the message to the client,
        # consistent with the other endpoints.
        logger.error(f"An unexpected error occurred during testing: {traceback.format_exc()}")
        return jsonify({"error": f"An unexpected error occurred during testing: {e}"}), 500
    finally:
        # Clean up the uploaded file.
        if temp_spec_path and os.path.exists(temp_spec_path):
            os.remove(temp_spec_path)
            logger.info(f"Cleaned up temporary file: {temp_spec_path}")
@app.route('/list-yapi-categories', methods=['POST'])
@login_required
def list_yapi_categories_endpoint():
    """Parse an uploaded YAPI export and return its categories as JSON.

    Responds with a list of ``{"name", "description"}`` objects. ``desc`` is
    preferred over ``description``, with placeholders when both are empty. The
    upload is stored in ``UPLOAD_FOLDER`` under a UUID-prefixed name only for
    the duration of the request and is removed in ``finally``.
    """
    uploaded = request.files.get('api_spec_file')
    if uploaded is None:
        return jsonify({"error": "api_spec_file part is missing"}), 400
    if uploaded.filename == '':
        return jsonify({"error": "No file selected"}), 400

    spec_path = None
    try:
        stored_name = f"{uuid.uuid4()}_{secure_filename(uploaded.filename)}"
        spec_path = os.path.join(app.config['UPLOAD_FOLDER'], stored_name)
        uploaded.save(spec_path)

        parsed_yapi = InputParser().parse_yapi_spec(spec_path)
        categories = getattr(parsed_yapi, 'categories', None) if parsed_yapi else None
        if not categories:
            return jsonify({"error": "Failed to parse YAPI categories or no categories found"}), 500

        payload = []
        for cat in categories:
            payload.append({
                "name": cat.get('name', '未命名'),
                "description": cat.get('desc') or cat.get('description') or '无描述',
            })
        return jsonify(payload), 200
    except Exception as e:
        logger.error(f"Error fetching YAPI categories: {traceback.format_exc()}")
        return jsonify({"error": str(e)}), 500
    finally:
        if spec_path and os.path.exists(spec_path):
            os.remove(spec_path)
@app.route('/list-swagger-tags', methods=['POST'])
@login_required
def list_swagger_tags_endpoint():
    """Parse an uploaded Swagger/OpenAPI document and return its tags as JSON.

    Responds with a list of ``{"name", "description"}`` objects, using
    placeholder text for missing keys. The upload is stored in
    ``UPLOAD_FOLDER`` under a UUID-prefixed name only for the duration of the
    request and is removed in ``finally``.
    """
    uploaded = request.files.get('api_spec_file')
    if uploaded is None:
        return jsonify({"error": "api_spec_file part is missing"}), 400
    if uploaded.filename == '':
        return jsonify({"error": "No file selected"}), 400

    spec_path = None
    try:
        stored_name = f"{uuid.uuid4()}_{secure_filename(uploaded.filename)}"
        spec_path = os.path.join(app.config['UPLOAD_FOLDER'], stored_name)
        uploaded.save(spec_path)

        parsed_swagger = InputParser().parse_swagger_spec(spec_path)
        tags = getattr(parsed_swagger, 'tags', None) if parsed_swagger else None
        if not tags:
            return jsonify({"error": "Failed to parse Swagger tags or no tags found"}), 500

        payload = []
        for tag in tags:
            payload.append({
                "name": tag.get('name', '未命名'),
                "description": tag.get('description', '无描述'),
            })
        return jsonify(payload), 200
    except Exception as e:
        logger.error(f"Error fetching Swagger tags: {traceback.format_exc()}")
        return jsonify({"error": str(e)}), 500
    finally:
        if spec_path and os.path.exists(spec_path):
            os.remove(spec_path)
@app.route('/download/<path:filepath>')
@login_required
def download_file(filepath):
    """Send a generated report from ``APP_ROOT/test_reports`` as an attachment.

    Args:
        filepath: Path relative to the reports directory. It is captured by
            the ``<path:...>`` URL converter, so it may contain ``/``; without
            the converter Flask would never pass this argument, and links such
            as ``/download/test_summary.json`` would not match the route.

    Aborts with 404 for absolute paths or paths containing ``..``.
    """
    from flask import abort

    # Simplified endpoint. For production, consider configuring the reports
    # directory centrally (run_tests_endpoint may write elsewhere).
    reports_base_dir = os.path.join(APP_ROOT, 'test_reports')
    # Cheap early rejection of traversal attempts; send_from_directory also
    # refuses paths that resolve outside reports_base_dir.
    if '..' in filepath or os.path.isabs(filepath):
        abort(404)
    logger.info(f"Attempting to serve file: {filepath} from directory: {reports_base_dir}")
    return send_from_directory(reports_base_dir, filepath, as_attachment=True)
# Resource file path helper
def resource_path(relative_path):
    """Return the absolute path of a resource, in development and bundled runs.

    Inside a PyInstaller bundle, *relative_path* is resolved against
    ``sys._MEIPASS`` (the bundle's extraction directory); otherwise it is
    resolved against the current working directory.
    """
    if hasattr(sys, '_MEIPASS'):
        base_path = sys._MEIPASS
    else:
        base_path = os.path.abspath(".")
    return os.path.join(base_path, relative_path)
if __name__ == '__main__':
    # NOTE: in production, run the app under a WSGI server such as Gunicorn or
    # uWSGI instead of the development server started below.
    # The database can also be (re)created explicitly with:
    #   flask --app flask_app init-db
    # Packaging: with PyInstaller only the static directory needs to be bundled,
    # because the DB schema is embedded in this module (no schema.sql), e.g.:
    #   pyinstaller --add-data "static:static" flask_app.py

    # Initialize the database when the file is missing, or when it exists but
    # lacks the `user` table; either way every login query would fail.
    init_reason = None
    if not os.path.exists(DATABASE):
        init_reason = "数据库不存在,正在初始化..."
    else:
        try:
            with app.app_context():
                get_db().execute("SELECT 1 FROM user LIMIT 1")
        except sqlite3.OperationalError:
            init_reason = "数据库中没有user表,正在重新初始化..."
    if init_reason is not None:
        logger.info(init_reason)
        init_db(force_create=True)
        logger.info("数据库初始化完成")
        # Create the default user
        create_default_user()

    # To enable HTTPS with an SSL certificate and key:
    # cert_path = os.path.join(APP_ROOT, 'ssl/cert.pem')
    # key_path = os.path.join(APP_ROOT, 'ssl/key.pem')
    # ssl_context = (cert_path, key_path)
    # app.run(debug=True, host='0.0.0.0', port=8443, ssl_context=ssl_context)
    app.run(debug=False, host='0.0.0.0', port=5050)