569 lines
23 KiB
Python
569 lines
23 KiB
Python
import os
|
||
import sys
|
||
import json
|
||
import logging
|
||
import argparse
|
||
import traceback # 用于更详细的错误日志
|
||
import uuid # For unique filenames
|
||
from pathlib import Path
|
||
import sqlite3 # <-- ADDED: For SQLite database
|
||
from werkzeug.security import generate_password_hash, check_password_hash # <-- ADDED: For password hashing
|
||
from werkzeug.utils import secure_filename # For safe filenames
|
||
from flask import Flask, request, jsonify, send_from_directory, session, redirect, url_for, render_template_string, g # <-- MODIFIED: Added session, redirect, url_for, render_template_string
|
||
from flask_cors import CORS # 用于处理跨域请求
|
||
|
||
# 将ddms_compliance_suite的父目录添加到sys.path
|
||
# 假设flask_app.py与ddms_compliance_suite目录在同一级别,或者ddms_compliance_suite在其PYTHONPATH中
|
||
# 如果 ddms_compliance_suite 是一个已安装的包,则不需要这个
|
||
# current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
# project_root = os.path.dirname(current_dir) # 假设项目根目录是上一级
|
||
# sys.path.insert(0, project_root)
|
||
# 或者更具体地添加包含ddms_compliance_suite的目录
|
||
# sys.path.insert(0, os.path.join(project_root, 'ddms_compliance_suite'))
|
||
|
||
from ddms_compliance_suite.api_caller.caller import APICallDetail
|
||
from ddms_compliance_suite.test_orchestrator import APITestOrchestrator, TestSummary
|
||
from ddms_compliance_suite.input_parser.parser import InputParser, ParsedYAPISpec, ParsedSwaggerSpec
|
||
# 从 run_api_tests.py 导入辅助函数 (如果它们被重构为可导入的)
|
||
# 为了简单起见,我们可能会直接在 flask_app.py 中重新实现一些逻辑或直接调用Orchestrator
|
||
|
||
# --- Application setup ---
# Static assets are served from ./static at the URL root (static_url_path='').
app = Flask(__name__, static_folder='static', static_url_path='')
# Allows cross-origin requests from every origin; production deployments
# should configure a stricter CORS policy.
CORS(app)

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Directory containing this script; uploads and the user database live next to it.
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
UPLOAD_FOLDER = os.path.join(APP_ROOT, 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
DATABASE = os.path.join(APP_ROOT, 'users.db')  # SQLite file holding user accounts

# Session-signing key. A key generated at import time changes on every process
# start (and differs between worker processes), which silently logs everyone
# out on restart/reload. Prefer a stable key from the SECRET_KEY environment
# variable and only fall back to a random per-process key when it is unset.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY') or os.urandom(24)
app.config['DATABASE'] = DATABASE
|
||
|
||
# --- 数据库辅助函数 ---
|
||
def get_db():
    """Return the SQLite connection cached on ``flask.g``, opening it on first use.

    The connection is stored as ``g._database`` and configured with
    ``sqlite3.Row`` so that columns can be accessed by name.
    """
    conn = getattr(g, '_database', None)
    if conn is not None:
        return conn

    conn = sqlite3.connect(DATABASE)
    g._database = conn
    conn.row_factory = sqlite3.Row  # rows behave like mappings keyed by column name
    return conn
|
||
|
||
@app.teardown_appcontext
def close_connection(exception):
    """Close the connection cached on ``g`` (if any) when the app context ends.

    ``exception`` is supplied by Flask's teardown hook and is not used here.
    """
    conn = getattr(g, '_database', None)
    if conn is None:
        return
    conn.close()
|
||
|
||
def init_db(force_create=False):
    """Run ``schema.sql`` against the database.

    The script is executed when ``force_create`` is true or when the database
    file does not exist yet; otherwise the function only logs that the
    database is already present.
    """
    # NOTE(review): sqlite3.connect() creates an empty file on first use, so a
    # request that touched get_db() before initialization makes this existence
    # check report "already exists" for a database without tables - confirm
    # that deployments always run the init step first.
    if not force_create and os.path.exists(DATABASE):
        logger.info("Database already exists.")
        return

    with app.app_context():
        conn = get_db()
        with app.open_resource('schema.sql', mode='r') as schema_file:
            schema_sql = schema_file.read()
        conn.cursor().executescript(schema_sql)
        conn.commit()
        logger.info("Database initialized!")
|
||
|
||
@app.cli.command('init-db')
def init_db_command():
    """``flask init-db``: run the schema script even if the database file exists."""
    init_db(force_create=True)
    confirmation = "Initialized the database."
    print(confirmation)
|
||
|
||
# --- 用户认证路由 ---
|
||
# Inline Jinja template for the registration page. Flashed messages are shown
# without categories, matching the flash() calls made by register().
REGISTER_TEMPLATE = '''
<!DOCTYPE html>
<html lang="zh-CN">
<head><meta charset="UTF-8"><title>注册</title></head>
<body>
    <h2>注册新用户</h2>
    {% with messages = get_flashed_messages() %}
        {% if messages %}
            <ul class=flashes>
            {% for message in messages %}
                <li>{{ message }}</li>
            {% endfor %}
            </ul>
        {% endif %}
    {% endwith %}
    <form method="post">
        <label for="username">用户名:</label>
        <input type="text" id="username" name="username" required><br><br>
        <label for="password">密码:</label>
        <input type="password" id="password" name="password" required><br><br>
        <input type="submit" value="注册">
    </form>
    <p>已有账户? <a href="{{ url_for('login') }}">点此登录</a></p>
</body>
</html>
'''
|
||
|
||
# Inline Jinja template for the login page. Flashed messages are rendered with
# their category as the CSS class (login() flashes 'success' / 'error').
LOGIN_TEMPLATE = '''
<!DOCTYPE html>
<html lang="zh-CN">
<head><meta charset="UTF-8"><title>登录</title></head>
<body>
    <h2>请登录</h2>
    {% with messages = get_flashed_messages(with_categories=true) %}
        {% if messages %}
            <ul class=flashes>
            {% for category, message in messages %}
                <li class="{{ category }}">{{ message }}</li>
            {% endfor %}
            </ul>
        {% endif %}
    {% endwith %}
    <form method="post">
        <label for="username">用户名:</label>
        <input type="text" id="username" name="username" required><br><br>
        <label for="password">密码:</label>
        <input type="password" id="password" name="password" required><br><br>
        <input type="submit" value="登录">
    </form>
    <p>没有账户? <a href="{{ url_for('register') }}">点此注册</a></p>
</body>
</html>
'''
|
||
|
||
@app.route('/register', methods=('GET', 'POST'))
def register():
    """Show the registration form (GET) or create a new account (POST).

    On success the user is redirected to the login page. On a validation
    failure or a duplicate username the error is flashed and the form is
    rendered again.
    """
    if request.method != 'POST':
        return render_template_string(REGISTER_TEMPLATE)

    username = request.form['username']
    password = request.form['password']
    db = get_db()

    if not username:
        error = '用户名是必需的.'
    elif not password:
        error = '密码是必需的.'
    else:
        error = None

    if error is None:
        insert_sql = "INSERT INTO user (username, password_hash) VALUES (?, ?)"
        try:
            # Only the salted hash of the password is stored.
            db.execute(insert_sql, (username, generate_password_hash(password)))
            db.commit()
        except db.IntegrityError:
            # Raised by the INSERT when the username is already taken.
            error = f"用户 {username} 已被注册."
        else:
            flash('注册成功! 请登录.')
            return redirect(url_for("login"))

    flash(error)
    return render_template_string(REGISTER_TEMPLATE)
|
||
|
||
@app.route('/login', methods=('GET', 'POST'))
def login():
    """Show the login form (GET) or authenticate the submitted credentials (POST).

    A successful login resets the session, stores ``user_id`` and ``username``
    in it and redirects to the index page. A failed login flashes the reason
    and falls through to render the form again.
    """
    if request.method == 'POST':
        username = request.form['username']
        password = request.form['password']
        account = get_db().execute(
            'SELECT * FROM user WHERE username = ?',
            (username,)
        ).fetchone()

        # NOTE(review): the two distinct messages reveal whether a username
        # exists; kept as-is to preserve the current user-facing behavior.
        if account is None:
            failure = '用户名不存在.'
        elif check_password_hash(account['password_hash'], password):
            failure = None
        else:
            failure = '密码错误.'

        if failure is None:
            session.clear()
            session['user_id'] = account['id']
            session['username'] = account['username']
            flash('登录成功!', 'success')
            return redirect(url_for('serve_index'))

        flash(failure, 'error')

    # Anyone who already has a session goes straight to the index page.
    if 'user_id' in session:
        return redirect(url_for('serve_index'))

    return render_template_string(LOGIN_TEMPLATE)
|
||
|
||
@app.route('/logout')
def logout():
    """Drop the current session, flash a confirmation and go to the login page."""
    session.clear()
    flash('您已成功登出.')
    login_url = url_for('login')
    return redirect(login_url)
|
||
|
||
# --- 应用辅助函数和路由保护 ---
|
||
from functools import wraps
|
||
from flask import g, flash # ensure g and flash are imported
|
||
|
||
def login_required(view):
    """Decorator: redirect to the login page unless ``g.user`` is set.

    The wrapped view receives the keyword arguments it was called with
    unchanged.
    """
    @wraps(view)
    def guarded_view(**kwargs):
        if g.user is not None:
            return view(**kwargs)
        return redirect(url_for('login'))

    return guarded_view
|
||
|
||
@app.before_request
def load_logged_in_user():
    """Populate ``g.user`` before each request from the session's ``user_id``.

    ``g.user`` is ``None`` when the session carries no ``user_id``; otherwise
    it is the matching row of the ``user`` table (or ``None`` if no row matches).
    """
    uid = session.get('user_id')
    if uid is not None:
        lookup = get_db().execute('SELECT * FROM user WHERE id = ?', (uid,))
        g.user = lookup.fetchone()
        return
    g.user = None
|
||
|
||
# --- 辅助函数 ---
|
||
def _write_fenced(out, language, content):
    """Write ``content`` to ``out`` as a Markdown fenced code block tagged ``language``."""
    out.write(f"```{language}\n{content}\n```\n")


def _to_pretty_json(value):
    """Serialize ``value`` as indented JSON, keeping non-ASCII characters readable."""
    return json.dumps(value, indent=2, ensure_ascii=False)


def _render_request_body(body):
    """Return the request body as pretty JSON, or its ``str()`` form if that fails."""
    try:
        parsed = json.loads(body) if isinstance(body, str) else body
        return _to_pretty_json(parsed)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError (body is not JSON text);
        # TypeError covers bodies json cannot serialize (e.g. bytes).
        return str(body)


def _render_response_body(body):
    """Return ``(fence_language, text)`` describing how to print a response body.

    Bytes are decoded as UTF-8 (falling back to ``str(body)``). Text that
    parses as JSON is pretty-printed under a ``json`` fence, other text is
    emitted verbatim under a ``text`` fence, and non-text values are dumped
    as JSON.
    """
    if isinstance(body, bytes):
        try:
            body = body.decode('utf-8')
        except UnicodeDecodeError:
            body = str(body)

    if not isinstance(body, str):
        return "json", _to_pretty_json(body)

    try:
        return "json", _to_pretty_json(json.loads(body))
    except json.JSONDecodeError:
        return "text", body


def _write_call_detail(out, detail, anchor_id):
    """Write one collapsible ``<details>`` section describing a single API call."""
    # The stored elapsed time is in seconds; the report shows milliseconds.
    elapsed_ms = detail.response_elapsed_time * 1000
    method = detail.request_method.upper()

    out.write(f"<details id='{anchor_id}'>\n")
    out.write(
        f"<summary><b>{method}</b> {detail.request_url} - "
        f"<b>状态: {detail.response_status_code}</b> - 耗时: {elapsed_ms:.2f}ms</summary>\n\n"
    )

    # Request section
    out.write("#### 请求 (Request)\n")
    out.write(f"- **Method:** `{method}`\n")
    out.write(f"- **URL:** `{detail.request_url}`\n")
    if detail.request_headers:
        out.write("- **Headers:**\n")
        _write_fenced(out, "json", _to_pretty_json(detail.request_headers))
    if detail.request_params:
        out.write("- **Query Parameters:**\n")
        _write_fenced(out, "json", _to_pretty_json(detail.request_params))
    if detail.request_body:
        out.write("- **Request Body:**\n")
        # The fence is tagged json even when the body falls back to plain text.
        _write_fenced(out, "json", _render_request_body(detail.request_body))

    out.write("- **cURL Command:**\n")
    _write_fenced(out, "bash", detail.curl_command)
    out.write("\n")

    # Response section
    out.write("#### 响应 (Response)\n")
    out.write(f"- **Status Code:** `{detail.response_status_code}`\n")
    out.write(f"- **Elapsed Time:** {elapsed_ms:.2f} ms\n")
    if detail.response_headers:
        out.write("- **Response Headers:**\n")
        _write_fenced(out, "json", _to_pretty_json(detail.response_headers))
    if detail.response_body:
        out.write("- **Response Body:**\n")
        try:
            language, text = _render_response_body(detail.response_body)
        except Exception as e_resp_body:
            # Best effort: one unprintable body must not abort the whole report.
            logger.error(f"Error processing response body for API call to {detail.request_url}: {e_resp_body}")
            language = "text"
            text = f"(Error processing response body: {e_resp_body})\n{detail.response_body}"
        _write_fenced(out, language, text)
    else:
        out.write("- Response Body: (empty)\n")

    out.write("\n</details>\n\n")
    out.write("---\n\n")  # separator between calls


def save_api_call_details_to_file(api_call_details: list[APICallDetail], output_dir_path_str: str, filename: str = "api_call_details.md"):
    """Write the recorded API calls to a Markdown report.

    Each call becomes a collapsible ``<details>`` section containing the
    request (method, URL, headers, query parameters, body, cURL command) and
    the response (status, elapsed time, headers, body).

    Args:
        api_call_details: Recorded calls to document; nothing is written when empty.
        output_dir_path_str: Directory for the report; created if missing.
        filename: Name of the report file inside the output directory.

    Returns:
        The path of the written file as a string, or ``None`` when there was
        nothing to save or writing failed (the failure is logged).
    """
    if not api_call_details:
        logger.info("没有API调用详情可保存。")
        return None

    output_dir = Path(output_dir_path_str)
    output_dir.mkdir(parents=True, exist_ok=True)
    file_path = output_dir / filename

    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            # Real newlines are required here: an escaped backslash-n would
            # put the whole report on one line and break every code fence.
            f.write("# API 调用详情记录\n\n")
            for index, detail_obj in enumerate(api_call_details, start=1):
                _write_call_detail(f, detail_obj, f"api-call-{index}")

        logger.info(f"API 调用详情已成功保存到: {file_path}")
        return str(file_path)
    except IOError as e:
        logger.error(f"保存 API 调用详情到文件时发生IO错误 '{file_path}': {e}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"保存 API 调用详情时发生未知错误 '{file_path}': {e}", exc_info=True)
        return None
|
||
|
||
# String values that must be read as "off" when a flag arrives as form text.
_FALSE_FLAG_STRINGS = frozenset({'', '0', 'false', 'no', 'off', 'none', 'null'})


def _coerce_flag(value):
    """Interpret a config flag that may be a real bool or form-encoded text.

    Strings are compared case-insensitively against the usual "off" spellings;
    any other string counts as enabled. Non-string values use plain truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in _FALSE_FLAG_STRINGS
    return bool(value)


def get_orchestrator_from_config(config: dict) -> APITestOrchestrator:
    """Build an ``APITestOrchestrator`` from a configuration dictionary.

    Args:
        config: Settings keyed by the orchestrator's keyword names. Missing
            keys fall back to ``''`` for ``base_url``, ``False`` for the
            ``use_llm_for_*`` flags and ``None`` for everything else.

    Returns:
        The configured orchestrator instance.
    """
    # The web endpoint passes request.form.to_dict(), where every value is a
    # string. Without coercion the text "false" would be truthy and silently
    # enable the corresponding LLM feature.
    return APITestOrchestrator(
        base_url=config.get('base_url', ''),
        custom_test_cases_dir=config.get('custom_test_cases_dir'),
        stages_dir=config.get('stages_dir'),
        llm_api_key=config.get('llm_api_key'),
        llm_base_url=config.get('llm_base_url'),
        llm_model_name=config.get('llm_model_name'),
        use_llm_for_request_body=_coerce_flag(config.get('use_llm_for_request_body', False)),
        use_llm_for_path_params=_coerce_flag(config.get('use_llm_for_path_params', False)),
        use_llm_for_query_params=_coerce_flag(config.get('use_llm_for_query_params', False)),
        use_llm_for_headers=_coerce_flag(config.get('use_llm_for_headers', False)),
        # The orchestrator may not write to this directory itself, but it is passed through.
        output_dir=config.get('output_dir')
    )
|
||
|
||
# --- API 端点 ---
|
||
@app.route('/')
@login_required  # the main page is only available to logged-in users
def serve_index():
    """Serve the single-page front end (``static/index.html``).

    As a simple bootstrap step, the database is created first if its file is
    missing; production setups would normally use a separate init step.
    """
    # NOTE(review): reaching this view requires a logged-in user, which in
    # turn requires an existing database, so this branch looks unreachable
    # in practice - confirm before relying on it for first-run setup.
    database_missing = not os.path.exists(DATABASE)
    if database_missing:
        init_db(force_create=True)
    return send_from_directory(app.static_folder, 'index.html')
|
||
|
||
@app.route('/run-tests', methods=['POST'])
@login_required  # Protect this endpoint
def run_tests_endpoint():
    """Run the compliance tests for an uploaded API specification.

    Expects a ``multipart/form-data`` POST containing the orchestrator
    settings as form fields and the spec in the ``api_spec_file`` part.
    The spec is parsed (YAPI or Swagger, chosen by ``api_spec_type``), the
    test cases and then the stages are executed, and the summary plus the
    API call details are written to the output directory.

    Returns:
        JSON with the summary and download links for both reports on
        success; ``{"error": ...}`` with status 400 for a missing/unparsable
        upload, or status 500 for an unexpected failure.
    """
    logger.info("Received request to run tests.")

    temp_spec_path = None  # path of the saved upload, removed in ``finally``

    try:
        # The form is sent as multipart/form-data, so every value is a string.
        config_data = request.form.to_dict()

        # Never write credentials (e.g. llm_api_key) to the log.
        sensitive_markers = ('key', 'token', 'secret', 'password')
        loggable_config = {
            name: '***' if value and any(marker in name.lower() for marker in sensitive_markers) else value
            for name, value in config_data.items()
        }
        logger.info(f"Received config: {loggable_config}")

        # Handle file upload
        if 'api_spec_file' not in request.files:
            logger.error("API spec file part is missing from the request.")
            return jsonify({"error": "API spec file part is missing"}), 400

        file = request.files['api_spec_file']
        if file.filename == '':
            logger.error("No API spec file selected.")
            return jsonify({"error": "No API spec file selected"}), 400

        if not file:
            # Should not normally be reached once the checks above passed.
            return jsonify({"error": "Invalid file object received"}), 400

        filename = secure_filename(file.filename)
        # A UUID prefix keeps concurrent uploads of the same name apart.
        unique_filename = f"{uuid.uuid4()}_{filename}"
        temp_spec_path = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
        file.save(temp_spec_path)
        logger.info(f"Saved uploaded spec file to {temp_spec_path}")

        # Create orchestrator from form data
        orchestrator = get_orchestrator_from_config(config_data)

        # Prepare summary object
        summary = TestSummary()

        # Determine API spec type and parse the uploaded file
        api_spec_type = config_data.get('api_spec_type', 'YAPI')
        logger.info(f"API Spec Type: {api_spec_type}")

        parser = InputParser()
        parsed_spec = None
        if api_spec_type == "YAPI":
            parsed_spec = parser.parse_yapi_spec(temp_spec_path)
        elif api_spec_type == "Swagger":
            parsed_spec = parser.parse_swagger_spec(temp_spec_path)

        if not parsed_spec:
            error_msg = f"Failed to parse the uploaded {api_spec_type} file."
            logger.error(error_msg)
            return jsonify({"error": error_msg}), 400

        logger.info("Successfully parsed API specification.")

        # Execute tests from parsed spec
        logger.info(f"Starting test execution from parsed {api_spec_type} spec...")
        summary = orchestrator._execute_tests_from_parsed_spec(
            parsed_spec=parsed_spec,
            summary=summary,
            categories=config_data.get('categories').split(',') if config_data.get('categories') else None,
            tags=config_data.get('tags').split(',') if config_data.get('tags') else None,
            custom_test_cases_dir=config_data.get('custom_test_cases_dir')
        )
        logger.info("Test case execution finished.")

        # Execute test stages
        logger.info("Starting stage execution...")
        summary = orchestrator.run_stages_from_spec(
            parsed_spec=parsed_spec,
            summary=summary
        )
        logger.info("Stage execution finished.")

        # Handle output and reporting. ``or`` (rather than a .get default)
        # also covers an empty form field, which would otherwise resolve to
        # APP_ROOT itself.
        # NOTE(review): output_dir is client-controlled and may be absolute,
        # so an authenticated user chooses where reports are written -
        # consider restricting it to the reports directory.
        output_dir = config_data.get('output_dir') or './test_reports'
        if not os.path.isabs(output_dir):
            output_dir = os.path.join(APP_ROOT, output_dir)

        os.makedirs(output_dir, exist_ok=True)
        summary_path = os.path.join(output_dir, "test_summary.json")

        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write(summary.to_json(pretty=True))
        logger.info(f"Test summary saved to {summary_path}")

        details_path = save_api_call_details_to_file(orchestrator.get_api_call_details(), output_dir)
        if details_path:
            logger.info(f"API call details saved to {details_path}")
        else:
            logger.warning("API call details were not saved (no calls recorded or writing failed).")

        # Build links that the /download/<path> endpoint can resolve: it
        # serves files relative to APP_ROOT/test_reports, so the URL must be
        # the report location relative to that directory.
        reports_base_dir = os.path.join(APP_ROOT, 'test_reports')
        try:
            relative_dir = os.path.relpath(output_dir, reports_base_dir)
        except ValueError:
            # os.path.relpath raises on Windows when the paths are on different drives.
            relative_dir = None

        outside_reports_dir = (
            relative_dir is None
            or relative_dir == os.pardir
            or relative_dir.startswith(os.pardir + os.sep)
        )
        if outside_reports_dir:
            # Keep the previous link shape so the response format is
            # unchanged, but make it visible that it cannot be served.
            logger.warning(
                f"Output directory {output_dir} is outside {reports_base_dir}; "
                "the report download links will not resolve."
            )
            download_prefix = f"/download/{os.path.basename(output_dir)}"
        elif relative_dir == os.curdir:
            download_prefix = "/download"
        else:
            download_prefix = "/download/" + relative_dir.replace(os.sep, '/')

        return jsonify({
            "summary": summary.to_dict(),
            "summary_report_path": f"{download_prefix}/test_summary.json",
            "details_report_path": f"{download_prefix}/api_call_details.md"
        })

    except Exception as e:
        # Full traceback goes to the server log only; the client receives the
        # exception message, matching the other endpoints in this module.
        logger.error(f"An unexpected error occurred during testing: {traceback.format_exc()}")
        return jsonify({"error": f"An unexpected error occurred during testing: {e}"}), 500
    finally:
        # Clean up the uploaded file
        if temp_spec_path and os.path.exists(temp_spec_path):
            os.remove(temp_spec_path)
            logger.info(f"Cleaned up temporary file: {temp_spec_path}")
|
||
|
||
@app.route('/list-yapi-categories', methods=['POST'])
@login_required
def list_yapi_categories_endpoint():
    """Return the categories found in an uploaded YAPI specification.

    The ``api_spec_file`` upload is stored temporarily, parsed and removed
    again. The response is a JSON list of ``{"name", "description"}``
    objects, or ``{"error": ...}`` with status 400 (bad upload) or 500
    (parse failure, no categories, unexpected error).
    """
    if 'api_spec_file' not in request.files:
        return jsonify({"error": "api_spec_file part is missing"}), 400

    upload = request.files['api_spec_file']
    if upload.filename == '':
        return jsonify({"error": "No file selected"}), 400

    temp_spec_path = None
    try:
        safe_name = secure_filename(upload.filename)
        temp_spec_path = os.path.join(
            app.config['UPLOAD_FOLDER'], f"{uuid.uuid4()}_{safe_name}"
        )
        upload.save(temp_spec_path)

        parsed_yapi = InputParser().parse_yapi_spec(temp_spec_path)

        if not (parsed_yapi and getattr(parsed_yapi, 'categories', None)):
            return jsonify({"error": "Failed to parse YAPI categories or no categories found"}), 500

        categories_list = []
        for cat in parsed_yapi.categories:
            # Prefer 'desc', then 'description', then the placeholder text.
            description = cat.get('desc') or cat.get('description') or '无描述'
            categories_list.append({
                "name": cat.get('name', '未命名'),
                "description": description
            })
        return jsonify(categories_list), 200
    except Exception as e:
        logger.error(f"Error fetching YAPI categories: {traceback.format_exc()}")
        return jsonify({"error": str(e)}), 500
    finally:
        # The upload is only needed for parsing; always remove it.
        if temp_spec_path and os.path.exists(temp_spec_path):
            os.remove(temp_spec_path)
|
||
|
||
@app.route('/list-swagger-tags', methods=['POST'])
@login_required
def list_swagger_tags_endpoint():
    """Return the tags found in an uploaded Swagger specification.

    The ``api_spec_file`` upload is stored temporarily, parsed and removed
    again. The response is a JSON list of ``{"name", "description"}``
    objects, or ``{"error": ...}`` with status 400 (bad upload) or 500
    (parse failure, no tags, unexpected error).
    """
    if 'api_spec_file' not in request.files:
        return jsonify({"error": "api_spec_file part is missing"}), 400

    upload = request.files['api_spec_file']
    if upload.filename == '':
        return jsonify({"error": "No file selected"}), 400

    temp_spec_path = None
    try:
        safe_name = secure_filename(upload.filename)
        temp_spec_path = os.path.join(
            app.config['UPLOAD_FOLDER'], f"{uuid.uuid4()}_{safe_name}"
        )
        upload.save(temp_spec_path)

        parsed_swagger = InputParser().parse_swagger_spec(temp_spec_path)

        if not (parsed_swagger and getattr(parsed_swagger, 'tags', None)):
            return jsonify({"error": "Failed to parse Swagger tags or no tags found"}), 500

        tags_list = []
        for tag in parsed_swagger.tags:
            tags_list.append({
                "name": tag.get('name', '未命名'),
                "description": tag.get('description', '无描述')
            })
        return jsonify(tags_list), 200
    except Exception as e:
        logger.error(f"Error fetching Swagger tags: {traceback.format_exc()}")
        return jsonify({"error": str(e)}), 500
    finally:
        # The upload is only needed for parsing; always remove it.
        if temp_spec_path and os.path.exists(temp_spec_path):
            os.remove(temp_spec_path)
|
||
|
||
@app.route('/download/<path:filepath>')
@login_required
def download_file(filepath):
    """Send a report file from ``APP_ROOT/test_reports`` as an attachment.

    This is a simplified download endpoint. For production, consider stricter
    checks on ``filepath`` and taking the reports directory from a central
    configuration.
    """
    from flask import abort

    reports_base_dir = os.path.join(APP_ROOT, 'test_reports')

    # Basic guard against directory traversal: refuse absolute paths and
    # anything containing '..'.
    looks_unsafe = '..' in filepath or os.path.isabs(filepath)
    if looks_unsafe:
        abort(404)

    logger.info(f"Attempting to serve file: {filepath} from directory: {reports_base_dir}")
    return send_from_directory(reports_base_dir, filepath, as_attachment=True)
|
||
|
||
if __name__ == '__main__':
    # NOTE: in production, run the app behind a WSGI server such as Gunicorn
    # or uWSGI instead of the built-in development server.
    #
    # For initial setup, the database has to be initialized once, either with
    #     flask --app flask_app init-db
    # or by uncommenting the line below for the very first run:
    # init_db(force_create=False)

    # The server listens on all interfaces, so the interactive debugger (which
    # allows executing arbitrary code) must not be enabled unconditionally.
    # Opt in explicitly for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled, host='0.0.0.0', port=5050)