compliance/flask_app.py
2025-06-06 14:52:08 +08:00

569 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import json
import logging
import argparse
import traceback # 用于更详细的错误日志
import uuid # For unique filenames
from pathlib import Path
import sqlite3 # <-- ADDED: For SQLite database
from werkzeug.security import generate_password_hash, check_password_hash # <-- ADDED: For password hashing
from werkzeug.utils import secure_filename # For safe filenames
from flask import Flask, request, jsonify, send_from_directory, session, redirect, url_for, render_template_string, g # <-- MODIFIED: Added session, redirect, url_for, render_template_string
from flask_cors import CORS # 用于处理跨域请求
# 将ddms_compliance_suite的父目录添加到sys.path
# 假设flask_app.py与ddms_compliance_suite目录在同一级别或者ddms_compliance_suite在其PYTHONPATH中
# 如果 ddms_compliance_suite 是一个已安装的包,则不需要这个
# current_dir = os.path.dirname(os.path.abspath(__file__))
# project_root = os.path.dirname(current_dir) # 假设项目根目录是上一级
# sys.path.insert(0, project_root)
# 或者更具体地添加包含ddms_compliance_suite的目录
# sys.path.insert(0, os.path.join(project_root, 'ddms_compliance_suite'))
from ddms_compliance_suite.api_caller.caller import APICallDetail
from ddms_compliance_suite.test_orchestrator import APITestOrchestrator, TestSummary
from ddms_compliance_suite.input_parser.parser import InputParser, ParsedYAPISpec, ParsedSwaggerSpec
# 从 run_api_tests.py 导入辅助函数 (如果它们被重构为可导入的)
# 为了简单起见,我们可能会直接在 flask_app.py 中重新实现一些逻辑或直接调用Orchestrator
app = Flask(__name__, static_folder='static', static_url_path='')
# Allow cross-origin requests from any origin. Tighten this (e.g. an explicit
# `origins=` list) before exposing the app outside a trusted network.
CORS(app)
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Directory containing this script; uploads and the user database live under it.
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
UPLOAD_FOLDER = os.path.join(APP_ROOT, 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
DATABASE = os.path.join(APP_ROOT, 'users.db')  # SQLite file holding the `user` table
# Key used to sign the session cookie. A fresh os.urandom() key on every start
# logs everyone out whenever the process restarts, and separate worker processes
# of a WSGI server may each generate a different key. Prefer a stable key from
# the FLASK_SECRET_KEY environment variable; the random key remains only as a
# development fallback.
app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY') or os.urandom(24)
app.config['DATABASE'] = DATABASE
# --- 数据库辅助函数 ---
def get_db():
    """Return the SQLite connection bound to the current app context.

    The connection is opened lazily on first use, cached on ``g._database``
    and configured with ``sqlite3.Row`` so columns can be read by name.
    """
    cached = getattr(g, '_database', None)
    if cached is not None:
        return cached
    connection = g._database = sqlite3.connect(DATABASE)
    connection.row_factory = sqlite3.Row  # access columns by name
    return connection
@app.teardown_appcontext
def close_connection(exception):
    """Close the context's SQLite connection, if get_db() opened one."""
    conn = getattr(g, '_database', None)
    if conn is None:
        return
    conn.close()
def init_db(force_create=False):
    """Create the database schema by executing ``schema.sql``.

    Args:
        force_create: When True, always execute ``schema.sql``. When False,
            execute it only if the ``user`` table does not exist yet, so an
            already-initialised database is left untouched.
    """
    with app.app_context():
        db = get_db()
        # Check for the table rather than the file: sqlite3.connect() creates
        # an empty database file, so "file exists" does not imply "schema exists".
        has_user_table = db.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'user'"
        ).fetchone() is not None
        if has_user_table and not force_create:
            logger.info("Database already exists.")
            return
        with app.open_resource('schema.sql', mode='r') as f:
            db.executescript(f.read())
        db.commit()
    logger.info("Database initialized!")
@app.cli.command('init-db')
def init_db_command():
    """`flask init-db`: execute schema.sql unconditionally (force_create=True)."""
    init_db(True)
    print("Initialized the database.")
# --- User authentication routes ---
# Inline Jinja template for the registration page, rendered by register() via
# render_template_string. Flashed messages are listed without categories.
REGISTER_TEMPLATE = '''
<!DOCTYPE html>
<html lang="zh-CN">
<head><meta charset="UTF-8"><title>注册</title></head>
<body>
<h2>注册新用户</h2>
{% with messages = get_flashed_messages() %}
{% if messages %}
<ul class=flashes>
{% for message in messages %}
<li>{{ message }}</li>
{% endfor %}
</ul>
{% endif %}
{% endwith %}
<form method="post">
<label for="username">用户名:</label>
<input type="text" id="username" name="username" required><br><br>
<label for="password">密码:</label>
<input type="password" id="password" name="password" required><br><br>
<input type="submit" value="注册">
</form>
<p>已有账户? <a href="{{ url_for('login') }}">点此登录</a></p>
</body>
</html>
'''
# Inline Jinja template for the login page, rendered by login() via
# render_template_string. Flashed messages are listed with their category used
# as the <li> CSS class.
LOGIN_TEMPLATE = '''
<!DOCTYPE html>
<html lang="zh-CN">
<head><meta charset="UTF-8"><title>登录</title></head>
<body>
<h2>请登录</h2>
{% with messages = get_flashed_messages(with_categories=true) %}
{% if messages %}
<ul class=flashes>
{% for category, message in messages %}
<li class="{{ category }}">{{ message }}</li>
{% endfor %}
</ul>
{% endif %}
{% endwith %}
<form method="post">
<label for="username">用户名:</label>
<input type="text" id="username" name="username" required><br><br>
<label for="password">密码:</label>
<input type="password" id="password" name="password" required><br><br>
<input type="submit" value="登录">
</form>
<p>没有账户? <a href="{{ url_for('register') }}">点此注册</a></p>
</body>
</html>
'''
@app.route('/register', methods=('GET', 'POST'))
def register():
    """Show the registration form (GET) or create a new account (POST).

    A successful POST stores a hashed password and redirects to the login page;
    any validation or uniqueness error is flashed and the form is re-rendered.
    """
    if request.method != 'POST':
        return render_template_string(REGISTER_TEMPLATE)

    username = request.form['username']
    password = request.form['password']
    db = get_db()
    if not username:
        error = '用户名是必需的.'
    elif not password:
        error = '密码是必需的.'
    else:
        try:
            db.execute(
                "INSERT INTO user (username, password_hash) VALUES (?, ?)",
                (username, generate_password_hash(password)),
            )
            db.commit()
        except db.IntegrityError:
            # Username already taken
            error = f"用户 {username} 已被注册."
        else:
            flash('注册成功! 请登录.')
            return redirect(url_for("login"))
    flash(error)
    return render_template_string(REGISTER_TEMPLATE)
@app.route('/login', methods=('GET', 'POST'))
def login():
    """Show the login form (GET) or authenticate a user (POST).

    On success the session is reset to the user's id/username and the client
    is redirected to the index page. A client that already has ``user_id`` in
    its session is redirected to the index instead of seeing the form.
    """
    if request.method == 'POST':
        username = request.form['username']
        password = request.form['password']
        user = get_db().execute(
            'SELECT * FROM user WHERE username = ?',
            (username,)
        ).fetchone()
        # One generic message for "unknown user" and "wrong password", so the
        # form cannot be used to probe which usernames are registered.
        if user is None or not check_password_hash(user['password_hash'], password):
            flash('用户名或密码错误.', 'error')
        else:
            session.clear()  # drop any previous session data before logging in
            session['user_id'] = user['id']
            session['username'] = user['username']
            flash('登录成功!', 'success')
            return redirect(url_for('serve_index'))
    # If user is already logged in, redirect to index
    if 'user_id' in session:
        return redirect(url_for('serve_index'))
    return render_template_string(LOGIN_TEMPLATE)
@app.route('/logout')
def logout():
    """End the session and send the user back to the login page."""
    session.clear()
    flash('您已成功登出.')
    login_url = url_for('login')
    return redirect(login_url)
# --- 应用辅助函数和路由保护 ---
from functools import wraps
from flask import g, flash # ensure g and flash are imported
def login_required(view):
    """Decorate a view so anonymous visitors are redirected to the login page.

    Relies on ``g.user`` having been set by the before-request hook.
    """
    @wraps(view)
    def guarded(**kwargs):
        if g.user is not None:
            return view(**kwargs)
        return redirect(url_for('login'))
    return guarded
@app.before_request
def load_logged_in_user():
    """Set ``g.user`` to the logged-in user's row, or ``None`` when anonymous.

    Runs before every request; ``login_required`` checks ``g.user``.
    """
    user_id = session.get('user_id')
    g.user = (
        None
        if user_id is None
        else get_db().execute('SELECT * FROM user WHERE id = ?', (user_id,)).fetchone()
    )
# --- 辅助函数 ---
def _md_fence(text, lang="json"):
    """Wrap *text* in a fenced Markdown code block labelled *lang*."""
    return f"```{lang}\n{text}\n```\n"


def _md_pretty_json(obj):
    """Pretty-print *obj* as JSON, keeping non-ASCII characters readable."""
    return json.dumps(obj, indent=2, ensure_ascii=False)


def _md_request_body(body):
    """Render a request body as a JSON block when possible, else as plain text."""
    try:
        parsed = json.loads(body) if isinstance(body, str) else body
        return _md_fence(_md_pretty_json(parsed))
    except (ValueError, TypeError):
        # ValueError: a string that is not JSON; TypeError: a value json cannot
        # serialise (e.g. bytes). Either way fall back to the raw string form.
        return _md_fence(str(body), "text")


def _md_response_body(body, request_url):
    """Render a response body: pretty JSON if parseable, otherwise plain text."""
    if isinstance(body, bytes):
        # Decode leniently so non-UTF-8 payloads still produce readable text.
        body = body.decode('utf-8', errors='replace')
    if isinstance(body, str):
        try:
            return _md_fence(_md_pretty_json(json.loads(body)))
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return _md_fence(body, "text")
    try:
        # Already a decoded structure (dict/list/...).
        return _md_fence(_md_pretty_json(body))
    except (TypeError, ValueError) as e_resp_body:
        logger.error(f"Error processing response body for API call to {request_url}: {e_resp_body}")
        return _md_fence(f"(Error processing response body: {e_resp_body})\n{body}", "text")


def _md_api_call_section(index, detail):
    """Render one API call as a collapsible ``<details>`` Markdown section."""
    method = detail.request_method.upper()
    # Elapsed time is stored in seconds; display milliseconds.
    elapsed_ms = detail.response_elapsed_time * 1000
    parts = [
        f"<details id='api-call-{index}'>\n",
        f"<summary><b>{method}</b> {detail.request_url} - <b>状态: {detail.response_status_code}</b> - 耗时: {elapsed_ms:.2f}ms</summary>\n\n",
        # Request details
        "#### 请求 (Request)\n",
        f"- **Method:** `{method}`\n",
        f"- **URL:** `{detail.request_url}`\n",
    ]
    if detail.request_headers:
        parts += ["- **Headers:**\n", _md_fence(_md_pretty_json(detail.request_headers))]
    if detail.request_params:
        parts += ["- **Query Parameters:**\n", _md_fence(_md_pretty_json(detail.request_params))]
    if detail.request_body:
        parts += ["- **Request Body:**\n", _md_request_body(detail.request_body)]
    parts += [
        "- **cURL Command:**\n",
        _md_fence(detail.curl_command or "", "bash"),
        "\n",
        # Response details
        "#### 响应 (Response)\n",
        f"- **Status Code:** `{detail.response_status_code}`\n",
        f"- **Elapsed Time:** {elapsed_ms:.2f} ms\n",
    ]
    if detail.response_headers:
        parts += ["- **Response Headers:**\n", _md_fence(_md_pretty_json(detail.response_headers))]
    if detail.response_body:
        parts += ["- **Response Body:**\n", _md_response_body(detail.response_body, detail.request_url)]
    else:
        parts.append("- Response Body: (empty)\n")
    # Close the section and add a horizontal-rule separator.
    parts.append("\n</details>\n\n---\n\n")
    return "".join(parts)


def save_api_call_details_to_file(api_call_details: "list[APICallDetail]", output_dir_path_str: str, filename: str = "api_call_details.md"):
    """Write API call details to a Markdown report.

    Each call becomes a collapsible ``<details>`` section containing the
    request (method, URL, headers, query parameters, body, cURL command) and
    the response (status, elapsed time, headers, body).

    Args:
        api_call_details: Recorded calls; each item must expose the
            ``request_*``, ``response_*`` and ``curl_command`` attributes read
            below.
        output_dir_path_str: Directory to write into (created if missing).
        filename: Report file name inside that directory.

    Returns:
        The written file path as a string, or ``None`` when there was nothing
        to save or writing failed. Errors are logged, not raised.
    """
    if not api_call_details:
        logger.info("没有API调用详情可保存。")
        return None
    output_dir = Path(output_dir_path_str)
    file_path = output_dir / filename
    try:
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write("# API 调用详情记录\n\n")
            for index, detail in enumerate(api_call_details, start=1):
                f.write(_md_api_call_section(index, detail))
        logger.info(f"API 调用详情已成功保存到: {file_path}")
        return str(file_path)
    except OSError as e:
        logger.error(f"保存 API 调用详情到文件时发生IO错误 '{file_path}': {e}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"保存 API 调用详情时发生未知错误 '{file_path}': {e}", exc_info=True)
        return None
# Lower-cased string values that count as "enabled" for boolean form options.
_TRUTHY_FORM_VALUES = frozenset({'1', 'true', 'yes', 'on'})


def _config_flag(config, key):
    """Read a boolean option given either as a real bool or as a form string."""
    value = config.get(key, False)
    if isinstance(value, str):
        # request.form values are strings, and "false" would otherwise be truthy.
        return value.strip().lower() in _TRUTHY_FORM_VALUES
    return bool(value)


def _config_optional(config, key):
    """Return ``config[key]``, mapping a missing or blank value to ``None``."""
    value = config.get(key)
    if isinstance(value, str) and not value.strip():
        return None
    return value


def get_orchestrator_from_config(config: dict) -> APITestOrchestrator:
    """Instantiate an APITestOrchestrator from a configuration mapping.

    *config* may hold native Python values or be the all-string dict produced
    by ``request.form.to_dict()``. ``use_llm_*`` flags accept bools or the
    strings "1"/"true"/"yes"/"on" (case-insensitive). Blank optional fields are
    passed as ``None`` rather than ``''``.

    Args:
        config: Keys read: base_url, custom_test_cases_dir, stages_dir,
            llm_api_key, llm_base_url, llm_model_name,
            use_llm_for_request_body, use_llm_for_path_params,
            use_llm_for_query_params, use_llm_for_headers, output_dir.

    Returns:
        A new APITestOrchestrator.
    """
    return APITestOrchestrator(
        base_url=config.get('base_url', ''),
        custom_test_cases_dir=_config_optional(config, 'custom_test_cases_dir'),
        stages_dir=_config_optional(config, 'stages_dir'),
        llm_api_key=_config_optional(config, 'llm_api_key'),
        llm_base_url=_config_optional(config, 'llm_base_url'),
        llm_model_name=_config_optional(config, 'llm_model_name'),
        use_llm_for_request_body=_config_flag(config, 'use_llm_for_request_body'),
        use_llm_for_path_params=_config_flag(config, 'use_llm_for_path_params'),
        use_llm_for_query_params=_config_flag(config, 'use_llm_for_query_params'),
        use_llm_for_headers=_config_flag(config, 'use_llm_for_headers'),
        # The orchestrator may not save reports itself, but it receives the directory.
        output_dir=_config_optional(config, 'output_dir')
    )
# --- API 端点 ---
@app.route('/')
@login_required  # Protect the main page
def serve_index():
    """Serve the front-end page (static/index.html) to logged-in users."""
    # A former "initialise the DB if users.db is missing" check here could never
    # fire: login_required needs g.user, which load_logged_in_user fetched via
    # get_db(), and sqlite3.connect() creates the database file. Initialise the
    # schema with `flask --app flask_app init-db` (or init_db()) instead.
    return send_from_directory(app.static_folder, 'index.html')
# Substrings of config keys whose values must never be written to the log.
_SENSITIVE_CONFIG_MARKERS = ("api_key", "token", "secret", "password")


def _redact_config(config):
    """Return a copy of *config* with secret-looking values masked for logging."""
    return {
        key: ('***' if value and any(m in key.lower() for m in _SENSITIVE_CONFIG_MARKERS) else value)
        for key, value in config.items()
    }


def _split_csv_field(raw):
    """Split a comma-separated form value into stripped, non-empty items, or None."""
    if not raw:
        return None
    items = [item.strip() for item in raw.split(',') if item.strip()]
    return items or None


def _report_download_url(file_path):
    """Build the ``/download/...`` URL under which download_file serves *file_path*.

    download_file serves from ``<APP_ROOT>/test_reports``, so the URL carries
    the path relative to that directory.
    """
    reports_base_dir = os.path.join(APP_ROOT, 'test_reports')
    try:
        rel_path = os.path.relpath(file_path, reports_base_dir)
    except ValueError:  # e.g. a different drive on Windows
        rel_path = os.pardir
    if rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep):
        logger.warning(f"Report {file_path} is outside {reports_base_dir}; its download link may not resolve.")
        # Keep the legacy link shape for reports outside the served directory.
        return f"/download/{os.path.basename(os.path.dirname(file_path))}/{os.path.basename(file_path)}"
    return "/download/" + Path(rel_path).as_posix()


@app.route('/run-tests', methods=['POST'])
@login_required  # Protect this endpoint
def run_tests_endpoint():
    """Run API compliance tests against an uploaded YAPI or Swagger spec.

    Expects a ``multipart/form-data`` POST with the spec file in
    ``api_spec_file`` plus string config fields (``api_spec_type``,
    ``categories``, ``tags``, ``output_dir``, orchestrator/LLM options).

    Returns:
        JSON ``{"summary", "summary_report_path", "details_report_path"}`` on
        success; ``{"error": ...}`` with 400 for bad input or 500 for an
        unexpected failure (full traceback is logged server-side only).
    """
    logger.info("Received request to run tests.")
    temp_spec_path = None  # uploaded spec file; removed in `finally`
    try:
        # The form is sent as multipart/form-data, so every value is a string.
        config_data = request.form.to_dict()
        logger.info(f"Received config: {_redact_config(config_data)}")

        if 'api_spec_file' not in request.files:
            logger.error("API spec file part is missing from the request.")
            return jsonify({"error": "API spec file part is missing"}), 400
        file = request.files['api_spec_file']
        if not file.filename:
            logger.error("No API spec file selected.")
            return jsonify({"error": "No API spec file selected"}), 400

        api_spec_type = config_data.get('api_spec_type', 'YAPI')
        logger.info(f"API Spec Type: {api_spec_type}")
        if api_spec_type not in ("YAPI", "Swagger"):
            error_msg = f"Unsupported API spec type: {api_spec_type}"
            logger.error(error_msg)
            return jsonify({"error": error_msg}), 400

        # UUID prefix so concurrent uploads with the same name don't collide.
        unique_filename = f"{uuid.uuid4()}_{secure_filename(file.filename)}"
        temp_spec_path = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
        file.save(temp_spec_path)
        logger.info(f"Saved uploaded spec file to {temp_spec_path}")

        parser = InputParser()
        if api_spec_type == "YAPI":
            parsed_spec = parser.parse_yapi_spec(temp_spec_path)
        else:
            parsed_spec = parser.parse_swagger_spec(temp_spec_path)
        if not parsed_spec:
            error_msg = f"Failed to parse the uploaded {api_spec_type} file."
            logger.error(error_msg)
            return jsonify({"error": error_msg}), 400
        logger.info("Successfully parsed API specification.")

        orchestrator = get_orchestrator_from_config(config_data)

        logger.info(f"Starting test execution from parsed {api_spec_type} spec...")
        summary = orchestrator._execute_tests_from_parsed_spec(
            parsed_spec=parsed_spec,
            summary=TestSummary(),
            categories=_split_csv_field(config_data.get('categories')),
            tags=_split_csv_field(config_data.get('tags')),
            custom_test_cases_dir=config_data.get('custom_test_cases_dir') or None
        )
        logger.info("Test case execution finished.")

        logger.info("Starting stage execution...")
        summary = orchestrator.run_stages_from_spec(
            parsed_spec=parsed_spec,
            summary=summary
        )
        logger.info("Stage execution finished.")

        # NOTE(review): output_dir comes straight from the client and may be any
        # absolute path the server can write to; consider confining it to the
        # reports directory. `or` (not a .get default) so a blank field also
        # falls back instead of writing into APP_ROOT itself.
        output_dir = config_data.get('output_dir') or './test_reports'
        if not os.path.isabs(output_dir):
            output_dir = os.path.join(APP_ROOT, output_dir)
        os.makedirs(output_dir, exist_ok=True)

        summary_path = os.path.join(output_dir, "test_summary.json")
        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write(summary.to_json(pretty=True))
        logger.info(f"Test summary saved to {summary_path}")

        details_path = os.path.join(output_dir, "api_call_details.md")
        if save_api_call_details_to_file(orchestrator.get_api_call_details(), output_dir):
            logger.info(f"API call details saved to {details_path}")

        return jsonify({
            "summary": summary.to_dict(),
            "summary_report_path": _report_download_url(summary_path),
            "details_report_path": _report_download_url(details_path)
        })
    except Exception as e:
        # Keep the traceback in the server log; don't expose it to the client.
        logger.exception("An unexpected error occurred during testing.")
        return jsonify({"error": f"An unexpected error occurred during testing: {e}"}), 500
    finally:
        # Clean up the uploaded file without letting a cleanup error mask the response.
        if temp_spec_path and os.path.exists(temp_spec_path):
            try:
                os.remove(temp_spec_path)
                logger.info(f"Cleaned up temporary file: {temp_spec_path}")
            except OSError as cleanup_err:
                logger.warning(f"Could not remove temporary file {temp_spec_path}: {cleanup_err}")
@app.route('/list-yapi-categories', methods=['POST'])
@login_required
def list_yapi_categories_endpoint():
    """List the categories of an uploaded YAPI export.

    Expects the spec in the ``api_spec_file`` multipart field. Returns a JSON
    list of ``{"name", "description"}`` objects, or ``{"error": ...}`` with
    400 (missing/empty upload) or 500 (parse failure, no categories, or any
    other exception). The uploaded file is always deleted afterwards.
    """
    if 'api_spec_file' not in request.files:
        return jsonify({"error": "api_spec_file part is missing"}), 400
    upload = request.files['api_spec_file']
    if upload.filename == '':
        return jsonify({"error": "No file selected"}), 400
    saved_path = None
    try:
        saved_path = os.path.join(
            app.config['UPLOAD_FOLDER'],
            f"{uuid.uuid4()}_{secure_filename(upload.filename)}",
        )
        upload.save(saved_path)
        spec = InputParser().parse_yapi_spec(saved_path)
        categories = getattr(spec, 'categories', None) if spec else None
        if not categories:
            return jsonify({"error": "Failed to parse YAPI categories or no categories found"}), 500
        result = []
        for category in categories:
            # Prefer 'desc', then 'description', then a placeholder.
            description = category.get('desc') or category.get('description') or '无描述'
            result.append({"name": category.get('name', '未命名'), "description": description})
        return jsonify(result), 200
    except Exception as e:
        logger.error(f"Error fetching YAPI categories: {traceback.format_exc()}")
        return jsonify({"error": str(e)}), 500
    finally:
        if saved_path and os.path.exists(saved_path):
            os.remove(saved_path)
@app.route('/list-swagger-tags', methods=['POST'])
@login_required
def list_swagger_tags_endpoint():
    """List the tags of an uploaded Swagger/OpenAPI spec.

    Expects the spec in the ``api_spec_file`` multipart field. Returns a JSON
    list of ``{"name", "description"}`` objects, or ``{"error": ...}`` with
    400 (missing/empty upload) or 500 (parse failure, no tags, or any other
    exception). The uploaded file is always deleted afterwards.
    """
    if 'api_spec_file' not in request.files:
        return jsonify({"error": "api_spec_file part is missing"}), 400
    file = request.files['api_spec_file']
    if file.filename == '':
        return jsonify({"error": "No file selected"}), 400
    temp_spec_path = None
    try:
        unique_filename = f"{uuid.uuid4()}_{secure_filename(file.filename)}"
        temp_spec_path = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
        file.save(temp_spec_path)
        parsed_swagger = InputParser().parse_swagger_spec(temp_spec_path)
        if not parsed_swagger or not getattr(parsed_swagger, 'tags', None):
            return jsonify({"error": "Failed to parse Swagger tags or no tags found"}), 500
        tags_list = [
            {
                # `or` rather than a .get() default so explicit None/"" values
                # also get the placeholder, as the YAPI categories endpoint
                # already does for descriptions.
                "name": tag.get('name') or '未命名',
                "description": tag.get('description') or '无描述',
            }
            for tag in parsed_swagger.tags
        ]
        return jsonify(tags_list), 200
    except Exception as e:
        logger.exception("Error fetching Swagger tags")
        return jsonify({"error": str(e)}), 500
    finally:
        if temp_spec_path and os.path.exists(temp_spec_path):
            os.remove(temp_spec_path)
@app.route('/download/<path:filepath>')
@login_required
def download_file(filepath):
    """Send a report file from ``<APP_ROOT>/test_reports`` as an attachment.

    Absolute paths and any path containing ``..`` are rejected with 404 before
    delegating to ``send_from_directory``.
    """
    # Simplified endpoint: production use should take the reports directory
    # from central configuration and validate the path more thoroughly.
    from flask import abort
    if os.path.isabs(filepath) or '..' in filepath:
        # Basic guard against directory traversal
        abort(404)
    reports_base_dir = os.path.join(APP_ROOT, 'test_reports')
    logger.info(f"Attempting to serve file: {filepath} from directory: {reports_base_dir}")
    return send_from_directory(reports_base_dir, filepath, as_attachment=True)
if __name__ == '__main__':
    # NOTE: in production, serve the app with a WSGI server such as Gunicorn or
    # uWSGI instead of Flask's built-in development server.
    arg_parser = argparse.ArgumentParser(
        description="Run the DDMS compliance test web app (development server)."
    )
    arg_parser.add_argument('--host', default='0.0.0.0',
                            help="Interface to bind (default: 0.0.0.0, all interfaces)")
    arg_parser.add_argument('--port', type=int, default=5050,
                            help="Port to listen on (default: 5050)")
    arg_parser.add_argument(
        '--debug',
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Flask debug mode with reloader (default: on). Use --no-debug whenever "
             "the server is reachable by others: the Werkzeug debugger can run "
             "arbitrary code.",
    )
    cli_args = arg_parser.parse_args()
    # Create the schema on first run. With force_create=False, init_db leaves an
    # existing database untouched. Use `flask --app flask_app init-db` to force
    # re-running schema.sql.
    init_db(force_create=False)
    app.run(debug=cli_args.debug, host=cli_args.host, port=cli_args.port)