import os
import sys
import json
import logging
import argparse
import threading
import traceback  # For detailed error logs
import uuid  # For unique filenames
import sqlite3  # SQLite user database
from collections.abc import Mapping
from functools import wraps
from pathlib import Path
from urllib.parse import urlparse

from werkzeug.security import generate_password_hash, check_password_hash  # Password hashing
from werkzeug.utils import secure_filename  # Safe upload filenames
from flask import (
    Flask,
    abort,
    flash,
    g,
    jsonify,
    redirect,
    render_template_string,
    request,
    send_from_directory,
    session,
    url_for,
)
from flask_cors import CORS  # Cross-origin request handling

# If ddms_compliance_suite is not installed as a package, add the directory that
# contains it to sys.path before these imports, e.g.:
#   sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ddms_compliance_suite.api_caller.caller import APICallDetail
from ddms_compliance_suite.test_orchestrator import APITestOrchestrator, TestSummary
from ddms_compliance_suite.input_parser.parser import InputParser, ParsedYAPISpec, ParsedSwaggerSpec

app = Flask(__name__, static_folder='static', static_url_path='')
CORS(app)  # NOTE: allows every origin; configure stricter rules for production.

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Directory containing this script; inside a PyInstaller bundle use sys._MEIPASS.
# NOTE(review): in PyInstaller one-file mode _MEIPASS is a temporary directory,
# so users.db, uploads and reports written below it do not persist between runs.
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
if hasattr(sys, '_MEIPASS'):
    APP_ROOT = sys._MEIPASS

UPLOAD_FOLDER = os.path.join(APP_ROOT, 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Default report directory; /download/<path> serves files from here only.
REPORTS_DIR = os.path.join(APP_ROOT, 'test_reports')

DATABASE = os.path.join(APP_ROOT, 'users.db')  # SQLite database path
# A stable key from the environment keeps sessions valid across restarts and
# worker processes; otherwise fall back to a random per-process key.
app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY') or os.urandom(24)
app.config['DATABASE'] = DATABASE

# Database schema, embedded so no external schema.sql is needed.
DB_SCHEMA = '''
DROP TABLE IF EXISTS user;
CREATE TABLE user (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    username TEXT UNIQUE NOT NULL,
    password_hash TEXT NOT NULL
);
'''


# --- Database helpers ---

def get_db():
    """Return the SQLite connection for the current app context, opening it on first use."""
    db = getattr(g, '_database', None)
    if db is None:
        db = g._database = sqlite3.connect(app.config['DATABASE'])
        db.row_factory = sqlite3.Row  # Access columns by name
    return db


@app.teardown_appcontext
def close_connection(exception):
    """Close the connection opened by get_db() when the app context ends."""
    db = getattr(g, '_database', None)
    if db is not None:
        db.close()


def init_db(force_create=False):
    """Create the schema from the embedded DB_SCHEMA.

    Args:
        force_create: Run the schema even if the database file exists.
            DB_SCHEMA drops the user table first, so this deletes all users.
    """
    if force_create or not os.path.exists(app.config['DATABASE']):
        with app.app_context():
            db = get_db()
            db.cursor().executescript(DB_SCHEMA)
            db.commit()
            logger.info("Database initialized!")
    else:
        logger.info("Database already exists.")


def create_default_user(username="admin", password="admin123"):
    """Create the account `username`/`password` unless that username already exists.

    NOTE(review): the default credentials are well known and this file has no
    password-change route; change them for real deployments.
    """
    with app.app_context():
        db = get_db()
        user = db.execute('SELECT * FROM user WHERE username = ?', (username,)).fetchone()
        if user is not None:
            logger.info(f"默认用户 {username} 已存在")
            return
        try:
            db.execute(
                "INSERT INTO user (username, password_hash) VALUES (?, ?)",
                (username, generate_password_hash(password))
            )
            db.commit()
        except sqlite3.IntegrityError:
            # Another thread/process inserted it between our SELECT and INSERT.
            logger.info(f"默认用户 {username} 已存在")
            return
        logger.info(f"已创建默认用户: {username}")


_db_init_lock = threading.Lock()
_db_ready = False


def ensure_db():
    """Make sure the user table exists and contains at least one account.

    Cheap after the first successful call (guarded by a module flag). The lock
    stops concurrent first requests from re-creating the table under each other.
    """
    global _db_ready
    if _db_ready:
        return
    with _db_init_lock:
        if _db_ready:
            return
        # Use a separate app context (own connection, closed on exit) so no
        # read statement is left open on the request's connection.
        with app.app_context():
            db = get_db()
            has_table = db.execute(
                "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'user'"
            ).fetchone() is not None
            has_users = has_table and db.execute('SELECT 1 FROM user LIMIT 1').fetchone() is not None
        if not has_table:
            logger.info("数据库中没有user表,正在初始化...")
            init_db(force_create=True)
            logger.info("数据库初始化完成")
        if not has_users:
            # Registration is disabled, so without this nobody could log in.
            create_default_user()
        _db_ready = True


@app.cli.command('init-db')
def init_db_command():
    """CLI command: (re)create the database and the default user."""
    init_db(force_create=True)
    create_default_user()
    print("Initialized the database.")


# --- Authentication routes ---

# Registration template (registration is currently disabled; kept for reference).
"""
REGISTER_TEMPLATE = '''
<!doctype html>
<html lang="zh-CN">
<head>
  <meta charset="utf-8">
  <title>注册</title>
</head>
<body>
  <h2>注册新用户</h2>
  {% with messages = get_flashed_messages() %}
    {% if messages %}
      <ul class="flashes">
        {% for message in messages %}<li>{{ message }}</li>{% endfor %}
      </ul>
    {% endif %}
  {% endwith %}
  <form method="post">
    <input type="text" name="username" placeholder="用户名" required>
    <input type="password" name="password" placeholder="密码" required>
    <input type="submit" value="注册">
  </form>
  <p>已有账户? <a href="{{ url_for('login') }}">点此登录</a></p>
</body>
</html>
'''
"""

# Login page. The field names must match what login() reads from request.form.
LOGIN_TEMPLATE = '''
<!doctype html>
<html lang="zh-CN">
<head>
  <meta charset="utf-8">
  <title>登录</title>
</head>
<body>
  <h2>请登录</h2>
  {% with messages = get_flashed_messages(with_categories=true) %}
    {% if messages %}
      <ul class="flashes">
        {% for category, message in messages %}
          <li class="{{ category }}">{{ message }}</li>
        {% endfor %}
      </ul>
    {% endif %}
  {% endwith %}
  <form method="post">
    <input type="text" name="username" placeholder="用户名" required autofocus>
    <input type="password" name="password" placeholder="密码" required>
    <input type="submit" value="登录">
  </form>
</body>
</html>
'''

# Registration is temporarily disabled:
# @app.route('/register', methods=('GET', 'POST'))
# def register():
#     if request.method == 'POST':
#         username = request.form['username']
#         password = request.form['password']
#         db = get_db()
#         error = None
#
#         if not username:
#             error = '用户名是必需的.'
#         elif not password:
#             error = '密码是必需的.'
#
#         if error is None:
#             try:
#                 db.execute(
#                     "INSERT INTO user (username, password_hash) VALUES (?, ?)",
#                     (username, generate_password_hash(password)),
#                 )
#                 db.commit()
#             except db.IntegrityError:  # Username already exists
#                 error = f"用户 {username} 已被注册."
#             else:
#                 flash('注册成功! 请登录.')
#                 return redirect(url_for("login"))
#
#         flash(error)
#     return render_template_string(REGISTER_TEMPLATE)


@app.route('/login', methods=('GET', 'POST'))
def login():
    """Show the login form (GET) or verify credentials and start a session (POST)."""
    if request.method == 'POST':
        username = request.form['username']
        password = request.form['password']
        db = get_db()
        error = None
        user = db.execute(
            'SELECT * FROM user WHERE username = ?', (username,)
        ).fetchone()

        if user is None:
            error = '用户名不存在.'
        elif not check_password_hash(user['password_hash'], password):
            error = '密码错误.'

        if error is None:
            session.clear()
            session['user_id'] = user['id']
            session['username'] = user['username']
            flash('登录成功!', 'success')
            return redirect(url_for('serve_index'))

        flash(error, 'error')

    # Redirect only if the session maps to an existing user; checking
    # 'user_id' in session alone loops forever with login_required when the
    # user row no longer exists.
    if getattr(g, 'user', None) is not None:
        return redirect(url_for('serve_index'))
    return render_template_string(LOGIN_TEMPLATE)


@app.route('/logout')
def logout():
    """End the session and return to the login page."""
    session.clear()
    flash('您已成功登出.')
    return redirect(url_for('login'))


# --- Route protection ---

def login_required(view):
    """Decorator: redirect to the login page unless a valid user is loaded in g.user."""
    @wraps(view)
    def wrapped_view(*args, **kwargs):
        if getattr(g, 'user', None) is None:
            return redirect(url_for('login'))
        return view(*args, **kwargs)
    return wrapped_view


@app.before_request
def load_logged_in_user():
    """Ensure the database is ready, then load the session's user row into g.user (or None)."""
    ensure_db()
    user_id = session.get('user_id')
    if user_id is None:
        g.user = None
    else:
        g.user = get_db().execute(
            'SELECT * FROM user WHERE id = ?', (user_id,)
        ).fetchone()


# --- Report helpers ---

def _to_json_text(data):
    """Pretty-print `data` as JSON for a human-readable report.

    Non-dict mappings (e.g. case-insensitive header dicts) are converted to
    dict first; any other non-serializable value falls back to str() so one
    odd field cannot abort the whole report.
    """
    if isinstance(data, Mapping) and not isinstance(data, dict):
        data = dict(data)
    return json.dumps(data, indent=2, ensure_ascii=False, default=str)


def _write_json_section(f, title, data):
    """Write a Markdown bullet `title` followed by `data` in a ```json fence."""
    f.write(f"- **{title}:**\n")
    f.write("```json\n")
    f.write(_to_json_text(data) + "\n")
    f.write("```\n")


def _write_request_body(f, body):
    """Write the request body, pretty-printed if it parses as JSON, else as-is."""
    try:
        parsed = json.loads(body) if isinstance(body, (str, bytes, bytearray)) else body
        text = _to_json_text(parsed)
    except ValueError:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
        text = str(body)
    f.write("- **Request Body:**\n")
    f.write("```json\n")
    f.write(text + "\n")
    f.write("```\n")


def _write_response_body(f, detail_obj):
    """Write the response body as a ```json fence if it is JSON, else as ```text."""
    f.write("- **Response Body:**\n")
    try:
        body = detail_obj.response_body
        if isinstance(body, bytes):
            try:
                body = body.decode('utf-8')
            except UnicodeDecodeError:
                body = str(body)  # Fallback if not UTF-8
        if isinstance(body, str):
            try:
                parsed = json.loads(body)
            except json.JSONDecodeError:
                fence, text = "text", body  # Not JSON: plain text
            else:
                fence, text = "json", _to_json_text(parsed)
        else:
            # Already a dict/list or similar object.
            fence, text = "json", _to_json_text(body)
    except Exception as e_resp_body:
        logger.error(f"Error processing response body for API call to {detail_obj.request_url}: {e_resp_body}")
        fence = "text"
        text = f"(Error processing response body: {e_resp_body})\n{detail_obj.response_body}"
    f.write(f"```{fence}\n")
    f.write(text + "\n")
    f.write("```\n")


def _write_api_call_detail(f, detail_obj, anchor_id):
    """Write one API call as a collapsible <details> section with request and response."""
    # response_elapsed_time is in seconds; guard against a missing value
    # (presumably possible for failed calls — confirm against APICallDetail).
    elapsed_ms = (detail_obj.response_elapsed_time or 0) * 1000
    method = detail_obj.request_method.upper()

    f.write(f'<details id="{anchor_id}">\n')
    f.write(f"<summary>{method} {detail_obj.request_url} - 状态: {detail_obj.response_status_code} - 耗时: {elapsed_ms:.2f}ms</summary>\n\n")

    # Use the URL path (no host, no query string) as the endpoint name.
    api_name = urlparse(detail_obj.request_url).path
    f.write(f"### 接口: {api_name}\n\n")

    f.write("#### 请求 (Request)\n")
    f.write(f"- **Method:** `{method}`\n")
    f.write(f"- **URL:** `{detail_obj.request_url}`\n")
    if detail_obj.request_headers:
        _write_json_section(f, "Headers", detail_obj.request_headers)
    if detail_obj.request_params:
        _write_json_section(f, "Query Parameters", detail_obj.request_params)
    if detail_obj.request_body:
        _write_request_body(f, detail_obj.request_body)
    f.write("- **cURL Command:**\n")
    f.write("```bash\n")
    f.write(f"{detail_obj.curl_command or ''}\n")
    f.write("```\n\n")

    f.write("#### 响应 (Response)\n")
    f.write(f"- **Status Code:** `{detail_obj.response_status_code}`\n")
    f.write(f"- **Elapsed Time:** {elapsed_ms:.2f} ms\n")
    if detail_obj.response_headers:
        _write_json_section(f, "Response Headers", detail_obj.response_headers)
    if detail_obj.response_body:
        _write_response_body(f, detail_obj)
    else:
        f.write("- Response Body: (empty)\n")

    f.write("\n</details>\n\n")
    f.write("---\n\n")  # Separator between calls


def save_api_call_details_to_file(api_call_details: list[APICallDetail], output_dir_path_str: str, filename: str = "api_call_details.md"):
    """Save API call details to a Markdown report.

    Args:
        api_call_details: Calls recorded by the orchestrator.
        output_dir_path_str: Directory for the report (created if missing).
        filename: Report file name inside that directory.

    Returns:
        The report path as a string, or None if there was nothing to save or
        writing failed (the error is logged).
    """
    if not api_call_details:
        logger.info("没有API调用详情可保存。")
        return None

    output_dir = Path(output_dir_path_str)
    output_dir.mkdir(parents=True, exist_ok=True)
    file_path = output_dir / filename

    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write("# API 调用详情记录\n\n")
            for index, detail_obj in enumerate(api_call_details, start=1):
                _write_api_call_detail(f, detail_obj, f"api-call-{index}")
        logger.info(f"API 调用详情已成功保存到: {file_path}")
        return str(file_path)
    except OSError as e:
        logger.error(f"保存 API 调用详情到文件时发生IO错误 '{file_path}': {e}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"保存 API 调用详情时发生未知错误 '{file_path}': {e}", exc_info=True)
        return None


# --- Request/config helpers ---

_TRUTHY_STRINGS = frozenset({'1', 'true', 'yes', 'y', 'on'})
_SENSITIVE_CONFIG_KEYS = frozenset({'llm_api_key'})


def _as_bool(value, default=False):
    """Interpret a form value as a boolean ('true'/'1'/'yes'/'on' are True; None -> default)."""
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        return value.strip().lower() in _TRUTHY_STRINGS
    return bool(value)


def _optional(config, key):
    """Return config[key], or None when it is missing or a blank string."""
    value = config.get(key)
    if isinstance(value, str) and not value.strip():
        return None
    return value


def _split_csv(value):
    """Split a comma-separated value into stripped, non-empty items; None if there are none."""
    if not value:
        return None
    items = [item.strip() for item in value.split(',') if item.strip()]
    return items or None


def _redact_config(config):
    """Return a copy of `config` with secret values masked, safe for logging."""
    return {k: ('***' if k in _SENSITIVE_CONFIG_KEYS and v else v) for k, v in config.items()}


def _save_uploaded_spec(file_storage):
    """Save an uploaded spec file under UPLOAD_FOLDER with a unique name; return its path."""
    filename = secure_filename(file_storage.filename)
    unique_filename = f"{uuid.uuid4()}_{filename}"  # Avoid collisions between uploads
    temp_spec_path = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
    file_storage.save(temp_spec_path)
    return temp_spec_path


def _remove_temp_file(path):
    """Best-effort removal of a temporary upload; failures are logged, not raised."""
    if not path or not os.path.exists(path):
        return
    try:
        os.remove(path)
        logger.info(f"Cleaned up temporary file: {path}")
    except OSError as e:
        logger.warning(f"Could not remove temporary file {path}: {e}")


def _download_url(directory, filename):
    """Return the /download URL for directory/filename, or None if it lies outside REPORTS_DIR."""
    reports_root = os.path.realpath(REPORTS_DIR)
    target = os.path.realpath(os.path.join(directory, filename))
    try:
        if os.path.commonpath([reports_root, target]) != reports_root:
            return None
    except ValueError:  # e.g. different drives on Windows
        return None
    relative = Path(os.path.relpath(target, reports_root)).as_posix()
    return url_for('download_file', filepath=relative)


def get_orchestrator_from_config(config: dict) -> APITestOrchestrator:
    """Instantiate an APITestOrchestrator from a config dict.

    Form values are always strings, so the use_llm_for_* flags are parsed with
    _as_bool (otherwise "false" would be truthy) and blank optional strings are
    passed as None.
    """
    return APITestOrchestrator(
        base_url=config.get('base_url', ''),
        custom_test_cases_dir=_optional(config, 'custom_test_cases_dir'),
        stages_dir=_optional(config, 'stages_dir'),
        llm_api_key=_optional(config, 'llm_api_key'),
        llm_base_url=_optional(config, 'llm_base_url'),
        llm_model_name=_optional(config, 'llm_model_name'),
        use_llm_for_request_body=_as_bool(config.get('use_llm_for_request_body')),
        use_llm_for_path_params=_as_bool(config.get('use_llm_for_path_params')),
        use_llm_for_query_params=_as_bool(config.get('use_llm_for_query_params')),
        use_llm_for_headers=_as_bool(config.get('use_llm_for_headers')),
        output_dir=_optional(config, 'output_dir'),  # Passed through; reports are saved by this app
    )


# --- API endpoints ---

@app.route('/')
@login_required
def serve_index():
    """Serve the single-page frontend (static/index.html) to logged-in users."""
    # Database setup happens in ensure_db(), called from load_logged_in_user
    # before every request, so the user table already exists at login time.
    return send_from_directory(app.static_folder, 'index.html')


@app.route('/run-tests', methods=['POST'])
@login_required
def run_tests_endpoint():
    """Run test cases and stages against an uploaded YAPI/Swagger spec.

    Expects multipart/form-data with an `api_spec_file` file plus config
    fields (api_spec_type, base_url, categories, tags, output_dir, LLM
    options, ...). Writes test_summary.json and api_call_details.md to the
    output directory and returns the summary with download links. A link is
    None when the file lies outside the served reports directory or was not
    written.
    """
    logger.info("Received request to run tests.")
    temp_spec_path = None
    try:
        config_data = request.form.to_dict()
        logger.info(f"Received config: {_redact_config(config_data)}")

        if 'api_spec_file' not in request.files:
            logger.error("API spec file part is missing from the request.")
            return jsonify({"error": "API spec file part is missing"}), 400
        file = request.files['api_spec_file']
        if file.filename == '':
            logger.error("No API spec file selected.")
            return jsonify({"error": "No API spec file selected"}), 400

        temp_spec_path = _save_uploaded_spec(file)
        logger.info(f"Saved uploaded spec file to {temp_spec_path}")

        api_spec_type = config_data.get('api_spec_type', 'YAPI')
        logger.info(f"API Spec Type: {api_spec_type}")
        parser = InputParser()
        if api_spec_type == "YAPI":
            parsed_spec = parser.parse_yapi_spec(temp_spec_path)
        elif api_spec_type == "Swagger":
            parsed_spec = parser.parse_swagger_spec(temp_spec_path)
        else:
            error_msg = f"Unsupported api_spec_type: {api_spec_type}"
            logger.error(error_msg)
            return jsonify({"error": error_msg}), 400

        if not parsed_spec:
            error_msg = f"Failed to parse the uploaded {api_spec_type} file."
            logger.error(error_msg)
            return jsonify({"error": error_msg}), 400
        logger.info("Successfully parsed API specification.")

        orchestrator = get_orchestrator_from_config(config_data)
        summary = TestSummary()

        logger.info(f"Starting test execution from parsed {api_spec_type} spec...")
        summary = orchestrator._execute_tests_from_parsed_spec(
            parsed_spec=parsed_spec,
            summary=summary,
            categories=_split_csv(config_data.get('categories')),
            tags=_split_csv(config_data.get('tags')),
            custom_test_cases_dir=_optional(config_data, 'custom_test_cases_dir')
        )
        logger.info("Test case execution finished.")

        logger.info("Starting stage execution...")
        summary = orchestrator.run_stages_from_spec(
            parsed_spec=parsed_spec,
            summary=summary
        )
        logger.info("Stage execution finished.")

        # A blank form value falls back to the default reports directory.
        output_dir = config_data.get('output_dir') or './test_reports'
        if not os.path.isabs(output_dir):
            output_dir = os.path.join(APP_ROOT, output_dir)
        os.makedirs(output_dir, exist_ok=True)

        summary_path = os.path.join(output_dir, "test_summary.json")
        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write(summary.to_json(pretty=True))
        logger.info(f"Test summary saved to {summary_path}")

        details_path = save_api_call_details_to_file(orchestrator.get_api_call_details(), output_dir)
        if details_path:
            logger.info(f"API call details saved to {details_path}")

        return jsonify({
            "summary": summary.to_dict(),
            "summary_report_path": _download_url(output_dir, "test_summary.json"),
            # No link when nothing was written, so a stale file is never served.
            "details_report_path": _download_url(output_dir, "api_call_details.md") if details_path else None,
        })

    except Exception:
        error_msg = f"An unexpected error occurred during testing: {traceback.format_exc()}"
        logger.error(error_msg)
        return jsonify({"error": error_msg}), 500
    finally:
        _remove_temp_file(temp_spec_path)


@app.route('/list-yapi-categories', methods=['POST'])
@login_required
def list_yapi_categories_endpoint():
    """Return [{name, description}] for the categories of an uploaded YAPI spec."""
    if 'api_spec_file' not in request.files:
        return jsonify({"error": "api_spec_file part is missing"}), 400
    file = request.files['api_spec_file']
    if file.filename == '':
        return jsonify({"error": "No file selected"}), 400

    temp_spec_path = None
    try:
        temp_spec_path = _save_uploaded_spec(file)
        parsed_yapi = InputParser().parse_yapi_spec(temp_spec_path)
        categories = getattr(parsed_yapi, 'categories', None) if parsed_yapi else None
        if not categories:
            return jsonify({"error": "Failed to parse YAPI categories or no categories found"}), 500

        categories_list = [
            {
                "name": cat.get('name', '未命名'),
                "description": cat.get('desc') or cat.get('description') or '无描述',
            }
            for cat in categories
        ]
        return jsonify(categories_list), 200
    except Exception as e:
        logger.error(f"Error fetching YAPI categories: {traceback.format_exc()}")
        return jsonify({"error": str(e)}), 500
    finally:
        _remove_temp_file(temp_spec_path)


@app.route('/list-swagger-tags', methods=['POST'])
@login_required
def list_swagger_tags_endpoint():
    """Return [{name, description}] for the tags of an uploaded Swagger spec."""
    if 'api_spec_file' not in request.files:
        return jsonify({"error": "api_spec_file part is missing"}), 400
    file = request.files['api_spec_file']
    if file.filename == '':
        return jsonify({"error": "No file selected"}), 400

    temp_spec_path = None
    try:
        temp_spec_path = _save_uploaded_spec(file)
        parsed_swagger = InputParser().parse_swagger_spec(temp_spec_path)
        tags = getattr(parsed_swagger, 'tags', None) if parsed_swagger else None
        if not tags:
            return jsonify({"error": "Failed to parse Swagger tags or no tags found"}), 500

        tags_list = [
            {"name": tag.get('name', '未命名'), "description": tag.get('description') or '无描述'}
            for tag in tags
        ]
        return jsonify(tags_list), 200
    except Exception as e:
        logger.error(f"Error fetching Swagger tags: {traceback.format_exc()}")
        return jsonify({"error": str(e)}), 500
    finally:
        _remove_temp_file(temp_spec_path)


@app.route('/download/<path:filepath>')
@login_required
def download_file(filepath):
    """Serve a report file from REPORTS_DIR as an attachment."""
    # Defense in depth: send_from_directory already refuses paths that escape
    # the directory, but reject obvious traversal attempts up front.
    if '..' in filepath or os.path.isabs(filepath):
        abort(404)
    logger.info(f"Attempting to serve file: {filepath} from directory: {REPORTS_DIR}")
    return send_from_directory(REPORTS_DIR, filepath, as_attachment=True)


def resource_path(relative_path):
    """Return the absolute path of a bundled resource (PyInstaller) or a cwd-relative one."""
    # PyInstaller unpacks bundled files into a folder stored in sys._MEIPASS.
    base_path = getattr(sys, '_MEIPASS', os.path.abspath("."))
    return os.path.join(base_path, relative_path)


if __name__ == '__main__':
    # NOTE: in production run the app under a WSGI server (Gunicorn, uWSGI, ...).
    # Packaging: the schema is embedded, so only the static directory is needed:
    #   pyinstaller --add-data "static:static" flask_app.py
    # The database can also be (re)created with: flask --app flask_app init-db
    ensure_db()

    # To enable HTTPS, pass ssl_context=(cert_path, key_path), e.g.
    # ssl/cert.pem and ssl/key.pem under APP_ROOT, and use port 8443.
    app.run(debug=False, host='0.0.0.0', port=5050)