"""Flask front end for the DDMS compliance suite.

Provides session-based user accounts backed by SQLite and HTTP endpoints that
run API compliance tests against an uploaded YAPI/Swagger specification.
"""
import argparse
import json
import logging
import os
import sqlite3  # SQLite-backed user store
import sys
import traceback  # detailed error logging
import uuid  # unique names for uploaded files
from functools import wraps
from pathlib import Path
from typing import Optional

from flask import (
    Flask,
    flash,
    g,
    jsonify,
    redirect,
    render_template_string,
    request,
    send_from_directory,
    session,
    url_for,
)
from flask_cors import CORS  # cross-origin request handling
from werkzeug.security import check_password_hash, generate_password_hash  # password hashing
from werkzeug.utils import secure_filename  # safe file names for uploads

# NOTE: if ddms_compliance_suite is not an installed package, its parent
# directory must be on sys.path (or PYTHONPATH) before the imports below, e.g.
#   sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ddms_compliance_suite.api_caller.caller import APICallDetail
from ddms_compliance_suite.test_orchestrator import APITestOrchestrator, TestSummary
from ddms_compliance_suite.input_parser.parser import InputParser, ParsedYAPISpec, ParsedSwaggerSpec

app = Flask(__name__, static_folder='static', static_url_path='')
CORS(app)  # Allows every origin; configure stricter rules for production.

# Logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Directory that contains this script; uploads, reports and the DB live below it.
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
UPLOAD_FOLDER = os.path.join(APP_ROOT, 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Directory the /download endpoint serves report files from.
REPORTS_BASE_DIR = os.path.join(APP_ROOT, 'test_reports')

DATABASE = os.path.join(APP_ROOT, 'users.db')
# A key from the environment keeps sessions valid across restarts and across
# worker processes; the random fallback preserves the previous behaviour.
app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY') or os.urandom(24)
app.config['DATABASE'] = DATABASE

# Set once the schema has been verified for this process.
_schema_checked = False


# --- Database helpers ---
def get_db():
    """Return the SQLite connection for the current app context.

    The connection is created on first use, cached on ``flask.g`` and closed
    by ``close_connection`` when the context is torn down.
    """
    db = getattr(g, '_database', None)
    if db is None:
        db = g._database = sqlite3.connect(app.config['DATABASE'])
        db.row_factory = sqlite3.Row  # access columns by name
    return db


@app.teardown_appcontext
def close_connection(exception):
    """Close the cached SQLite connection, if one was opened."""
    db = getattr(g, '_database', None)
    if db is not None:
        db.close()


def _schema_installed() -> bool:
    """Return True when the ``user`` table exists in the configured database."""
    row = get_db().execute(
        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'user'"
    ).fetchone()
    return row is not None


def init_db(force_create=False):
    """Initialize the database from ``schema.sql``.

    The schema script is executed when ``force_create`` is true, when the
    database file does not exist, or when the file exists but has no ``user``
    table. The last case matters because ``sqlite3.connect`` creates an empty
    file on first use, so a file-existence test alone is not sufficient.
    """
    with app.app_context():
        if (
            not force_create
            and os.path.exists(app.config['DATABASE'])
            and _schema_installed()
        ):
            logger.info("Database already exists.")
            return
        db = get_db()
        with app.open_resource('schema.sql', mode='r') as f:
            db.cursor().executescript(f.read())
        db.commit()
        logger.info("Database initialized!")


@app.cli.command('init-db')
def init_db_command():
    """CLI command to initialize the database."""
    init_db(force_create=True)
    print("Initialized the database.")


@app.before_request
def ensure_database_ready():
    """Create the schema on the first request of a fresh installation.

    Without this, the first login/registration POST fails with
    "no such table: user" unless ``flask init-db`` was run beforehand.
    """
    global _schema_checked
    if not _schema_checked:
        init_db()
        _schema_checked = True


# --- Authentication routes ---
# NOTE(review): the HTML markup of both templates was missing from the
# reviewed copy of this file. The markup below is a minimal reconstruction
# built around the form fields the views read (``username``/``password``) and
# the surviving text; restore the original markup if it is available.
REGISTER_TEMPLATE = '''
<!doctype html>
<html lang="zh">
<head>
  <meta charset="utf-8">
  <title>注册</title>
</head>
<body>
  <h2>注册新用户</h2>
  {% with messages = get_flashed_messages() %}
    {% if messages %}
      <ul class="flashes">
      {% for message in messages %}
        <li>{{ message }}</li>
      {% endfor %}
      </ul>
    {% endif %}
  {% endwith %}
  <form method="post">
    <input type="text" name="username" placeholder="用户名" required>
    <input type="password" name="password" placeholder="密码" required>
    <input type="submit" value="注册">
  </form>
  <p>已有账户? <a href="{{ url_for('login') }}">点此登录</a></p>
</body>
</html>
'''

LOGIN_TEMPLATE = '''
<!doctype html>
<html lang="zh">
<head>
  <meta charset="utf-8">
  <title>登录</title>
</head>
<body>
  <h2>请登录</h2>
  {% with messages = get_flashed_messages(with_categories=true) %}
    {% if messages %}
      <ul class="flashes">
      {% for category, message in messages %}
        <li class="{{ category }}">{{ message }}</li>
      {% endfor %}
      </ul>
    {% endif %}
  {% endwith %}
  <form method="post">
    <input type="text" name="username" placeholder="用户名" required>
    <input type="password" name="password" placeholder="密码" required>
    <input type="submit" value="登录">
  </form>
  <p>没有账户? <a href="{{ url_for('register') }}">点此注册</a></p>
</body>
</html>
'''


@app.route('/register', methods=('GET', 'POST'))
def register():
    """Show the registration form and create a user on POST.

    A successful registration redirects to the login page; validation
    failures and duplicate user names are flashed and the form is shown again.
    """
    if request.method == 'POST':
        username = request.form['username']
        password = request.form['password']
        db = get_db()
        error = None

        if not username:
            error = '用户名是必需的.'
        elif not password:
            error = '密码是必需的.'

        if error is None:
            try:
                db.execute(
                    "INSERT INTO user (username, password_hash) VALUES (?, ?)",
                    (username, generate_password_hash(password)),
                )
                db.commit()
            except sqlite3.IntegrityError:  # username already exists
                error = f"用户 {username} 已被注册."
            else:
                flash('注册成功! 请登录.')
                return redirect(url_for("login"))

        flash(error)

    return render_template_string(REGISTER_TEMPLATE)


@app.route('/login', methods=('GET', 'POST'))
def login():
    """Show the login form and start a session on valid credentials."""
    if request.method == 'POST':
        username = request.form['username']
        password = request.form['password']
        db = get_db()
        error = None
        user = db.execute(
            'SELECT * FROM user WHERE username = ?', (username,)
        ).fetchone()

        if user is None:
            error = '用户名不存在.'
        elif not check_password_hash(user['password_hash'], password):
            error = '密码错误.'

        if error is None:
            # Start from a clean session so no stale data survives the login.
            session.clear()
            session['user_id'] = user['id']
            session['username'] = user['username']
            flash('登录成功!', 'success')
            return redirect(url_for('serve_index'))

        flash(error, 'error')

    # A user who is already logged in goes straight to the index page.
    if 'user_id' in session:
        return redirect(url_for('serve_index'))

    return render_template_string(LOGIN_TEMPLATE)


@app.route('/logout')
def logout():
    """Clear the session and return to the login page."""
    session.clear()
    flash('您已成功登出.')
    return redirect(url_for('login'))


# --- Route protection ---
def login_required(view):
    """Decorator that redirects anonymous users to the login page."""
    @wraps(view)
    def wrapped_view(**kwargs):
        if g.user is None:
            return redirect(url_for('login'))
        return view(**kwargs)
    return wrapped_view


@app.before_request
def load_logged_in_user():
    """Load the session's user row into ``g.user`` (``None`` when anonymous)."""
    user_id = session.get('user_id')
    if user_id is None:
        g.user = None
    else:
        g.user = get_db().execute(
            'SELECT * FROM user WHERE id = ?', (user_id,)
        ).fetchone()


# --- Report helpers ---
def _dump_json(value) -> str:
    """Pretty-print ``value`` as JSON without escaping non-ASCII characters."""
    return json.dumps(value, indent=2, ensure_ascii=False)


def _write_fenced_block(out, language: str, text: str) -> None:
    """Write ``text`` to ``out`` as a Markdown fenced code block."""
    out.write(f"```{language}\n{text}\n```\n")


def _render_request_body(body) -> str:
    """Return the request body as pretty JSON, or ``str(body)`` if it is not JSON."""
    try:
        parsed = json.loads(body) if isinstance(body, str) else body
        return _dump_json(parsed)
    except (json.JSONDecodeError, TypeError):
        # Not JSON text, or an object json cannot serialise (e.g. bytes).
        return str(body)


def _render_response_body(body) -> tuple[str, str]:
    """Return ``(fence_language, text)`` for a response body.

    Bytes are decoded as UTF-8 when possible. Strings that parse as JSON are
    pretty-printed, other strings are emitted as plain text, and any other
    object is dumped as JSON.
    """
    if isinstance(body, bytes):
        try:
            body = body.decode('utf-8')
        except UnicodeDecodeError:
            body = str(body)  # fall back to the bytes repr
    if isinstance(body, str):
        try:
            return 'json', _dump_json(json.loads(body))
        except json.JSONDecodeError:
            return 'text', body
    return 'json', _dump_json(body)


def _write_api_call_detail(out, detail_obj, unique_id: str) -> None:
    """Write one API call as a collapsible Markdown section to ``out``."""
    method = detail_obj.request_method.upper()
    # Elapsed time is stored in seconds; show milliseconds. A missing value
    # is rendered as 0 instead of aborting the whole report.
    elapsed_ms = (detail_obj.response_elapsed_time or 0) * 1000

    out.write(f'<details id="{unique_id}">\n')
    out.write(
        f"<summary>{method} {detail_obj.request_url} - 状态: "
        f"{detail_obj.response_status_code} - 耗时: {elapsed_ms:.2f}ms</summary>\n\n"
    )

    # Request details.
    out.write("#### 请求 (Request)\n")
    out.write(f"- **Method:** `{method}`\n")
    out.write(f"- **URL:** `{detail_obj.request_url}`\n")
    if detail_obj.request_headers:
        out.write("- **Headers:**\n")
        _write_fenced_block(out, 'json', _dump_json(detail_obj.request_headers))
    if detail_obj.request_params:
        out.write("- **Query Parameters:**\n")
        _write_fenced_block(out, 'json', _dump_json(detail_obj.request_params))
    if detail_obj.request_body:
        out.write("- **Request Body:**\n")
        _write_fenced_block(out, 'json', _render_request_body(detail_obj.request_body))
    out.write("- **cURL Command:**\n")
    _write_fenced_block(out, 'bash', f"{detail_obj.curl_command}")
    out.write("\n")

    # Response details.
    out.write("#### 响应 (Response)\n")
    out.write(f"- **Status Code:** `{detail_obj.response_status_code}`\n")
    out.write(f"- **Elapsed Time:** {elapsed_ms:.2f} ms\n")
    if detail_obj.response_headers:
        out.write("- **Response Headers:**\n")
        _write_fenced_block(out, 'json', _dump_json(detail_obj.response_headers))
    if detail_obj.response_body:
        out.write("- **Response Body:**\n")
        try:
            language, text = _render_response_body(detail_obj.response_body)
        except Exception as e_resp_body:
            # Best effort: one unrenderable body must not lose the report.
            logger.error(f"Error processing response body for API call to {detail_obj.request_url}: {e_resp_body}")
            language = 'text'
            text = f"(Error processing response body: {e_resp_body})\n{detail_obj.response_body}"
        _write_fenced_block(out, language, text)
    else:
        out.write("- Response Body: (empty)\n")

    out.write("\n</details>\n\n")
    out.write("---\n\n")  # separator between calls


def save_api_call_details_to_file(api_call_details: list[APICallDetail], output_dir_path_str: str, filename: str = "api_call_details.md") -> Optional[str]:
    """Save API call details to a Markdown file.

    Args:
        api_call_details: Recorded calls to write, in order.
        output_dir_path_str: Directory for the report; created if missing.
        filename: Name of the Markdown file inside that directory.

    Returns:
        The path of the written file as a string, or ``None`` when there was
        nothing to save or writing failed (the failure is logged).
    """
    if not api_call_details:
        logger.info("没有API调用详情可保存。")
        return None

    output_dir = Path(output_dir_path_str)
    output_dir.mkdir(parents=True, exist_ok=True)
    file_path = output_dir / filename

    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write("# API 调用详情记录\n\n")
            for index, detail_obj in enumerate(api_call_details, start=1):
                _write_api_call_detail(f, detail_obj, f"api-call-{index}")
        logger.info(f"API 调用详情已成功保存到: {file_path}")
        return str(file_path)
    except IOError as e:
        logger.error(f"保存 API 调用详情到文件时发生IO错误 '{file_path}': {e}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"保存 API 调用详情时发生未知错误 '{file_path}': {e}", exc_info=True)
        return None


_TRUE_STRINGS = frozenset({'1', 'true', 'yes', 'on'})


def _as_bool(value, default: bool = False) -> bool:
    """Interpret a config value as a boolean.

    Form submissions deliver every value as a string, so ``"false"`` must not
    be treated as truthy. Non-string values keep their normal truthiness.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() in _TRUE_STRINGS
    return bool(value)


def get_orchestrator_from_config(config: dict) -> APITestOrchestrator:
    """Instantiate an ``APITestOrchestrator`` from a configuration mapping.

    The ``use_llm_*`` flags accept real booleans as well as the string forms
    produced by HTML forms (``"true"``, ``"on"``, ``"1"``, ``"yes"``).
    """
    return APITestOrchestrator(
        base_url=config.get('base_url', ''),
        custom_test_cases_dir=config.get('custom_test_cases_dir'),
        stages_dir=config.get('stages_dir'),
        llm_api_key=config.get('llm_api_key'),
        llm_base_url=config.get('llm_base_url'),
        llm_model_name=config.get('llm_model_name'),
        use_llm_for_request_body=_as_bool(config.get('use_llm_for_request_body')),
        use_llm_for_path_params=_as_bool(config.get('use_llm_for_path_params')),
        use_llm_for_query_params=_as_bool(config.get('use_llm_for_query_params')),
        use_llm_for_headers=_as_bool(config.get('use_llm_for_headers')),
        # The orchestrator may not write to it itself, but it is passed along.
        output_dir=config.get('output_dir')
    )


def _redact_config(config: dict) -> dict:
    """Return a copy of ``config`` that is safe to log (secrets masked)."""
    return {
        key: '***' if 'key' in key.lower() and value else value
        for key, value in config.items()
    }


def _report_download_url(output_dir: str, filename: str) -> str:
    """Return the ``/download/...`` URL for a report file.

    The URL is relative to ``REPORTS_BASE_DIR``. For an ``output_dir`` outside
    that directory the previous URL shape is returned unchanged and a warning
    is logged, because such a file cannot be served from the reports directory.
    """
    try:
        relative = os.path.relpath(os.path.join(output_dir, filename), REPORTS_BASE_DIR)
    except ValueError:  # e.g. different drives on Windows
        relative = None
    outside = (
        relative is None
        or relative == os.pardir
        or relative.startswith(os.pardir + os.sep)
        or os.path.isabs(relative)
    )
    if outside:
        logger.warning(
            "Report directory %s is outside %s; the download link will not resolve.",
            output_dir, REPORTS_BASE_DIR,
        )
        return f"/download/{os.path.basename(output_dir)}/{filename}"
    return "/download/" + relative.replace(os.sep, '/')


# --- API endpoints ---
@app.route('/')
@login_required  # protect the main page
def serve_index():
    """Serve the single-page front end."""
    # Defensive re-initialisation in case the database file was removed
    # while the server is running.
    if not os.path.exists(app.config['DATABASE']):
        init_db(force_create=True)
    return send_from_directory(app.static_folder, 'index.html')


@app.route('/run-tests', methods=['POST'])
@login_required  # protect this endpoint
def run_tests_endpoint():
    """Run the compliance tests against an uploaded API specification.

    Expects ``multipart/form-data`` with the configuration as form fields and
    the specification in the ``api_spec_file`` part. Responds with the test
    summary and download links for the generated reports.
    """
    logger.info("Received request to run tests.")
    temp_spec_path = None  # path of the uploaded spec file, for cleanup

    try:
        config_data = request.form.to_dict()
        # The form may carry an LLM API key; never write it to the log.
        logger.info(f"Received config: {_redact_config(config_data)}")

        # Handle the file upload.
        if 'api_spec_file' not in request.files:
            logger.error("API spec file part is missing from the request.")
            return jsonify({"error": "API spec file part is missing"}), 400

        file = request.files['api_spec_file']
        if file.filename == '':
            logger.error("No API spec file selected.")
            return jsonify({"error": "No API spec file selected"}), 400
        if not file:
            # Should not be reachable given the checks above.
            return jsonify({"error": "Invalid file object received"}), 400

        filename = secure_filename(file.filename)
        # A unique prefix avoids collisions between concurrent uploads.
        unique_filename = f"{uuid.uuid4()}_{filename}"
        temp_spec_path = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
        file.save(temp_spec_path)
        logger.info(f"Saved uploaded spec file to {temp_spec_path}")

        orchestrator = get_orchestrator_from_config(config_data)
        summary = TestSummary()

        # Determine the spec type and parse the uploaded file.
        api_spec_type = config_data.get('api_spec_type', 'YAPI')
        logger.info(f"API Spec Type: {api_spec_type}")

        parser = InputParser()
        parsed_spec = None
        if api_spec_type == "YAPI":
            parsed_spec = parser.parse_yapi_spec(temp_spec_path)
        elif api_spec_type == "Swagger":
            parsed_spec = parser.parse_swagger_spec(temp_spec_path)

        if not parsed_spec:
            error_msg = f"Failed to parse the uploaded {api_spec_type} file."
            logger.error(error_msg)
            return jsonify({"error": error_msg}), 400

        logger.info("Successfully parsed API specification.")

        # Execute the test cases.
        logger.info(f"Starting test execution from parsed {api_spec_type} spec...")
        categories = config_data.get('categories')
        tags = config_data.get('tags')
        summary = orchestrator._execute_tests_from_parsed_spec(
            parsed_spec=parsed_spec,
            summary=summary,
            categories=categories.split(',') if categories else None,
            tags=tags.split(',') if tags else None,
            custom_test_cases_dir=config_data.get('custom_test_cases_dir')
        )
        logger.info("Test case execution finished.")

        # Execute the test stages.
        logger.info("Starting stage execution...")
        summary = orchestrator.run_stages_from_spec(
            parsed_spec=parsed_spec,
            summary=summary
        )
        logger.info("Stage execution finished.")

        # Write the reports.
        output_dir = config_data.get('output_dir', './test_reports')
        if not os.path.isabs(output_dir):
            output_dir = os.path.join(APP_ROOT, output_dir)
        os.makedirs(output_dir, exist_ok=True)

        summary_path = os.path.join(output_dir, "test_summary.json")
        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write(summary.to_json(pretty=True))
        logger.info(f"Test summary saved to {summary_path}")

        details_path = save_api_call_details_to_file(orchestrator.get_api_call_details(), output_dir)
        if details_path:
            logger.info(f"API call details saved to {details_path}")

        return jsonify({
            "summary": summary.to_dict(),
            "summary_report_path": _report_download_url(output_dir, "test_summary.json"),
            "details_report_path": _report_download_url(output_dir, "api_call_details.md")
        })

    except Exception as e:
        # Keep the traceback in the server log; the client only gets the
        # message, matching the other endpoints.
        logger.error(f"An unexpected error occurred during testing: {traceback.format_exc()}")
        return jsonify({"error": f"An unexpected error occurred during testing: {e}"}), 500
    finally:
        # Clean up the uploaded file.
        if temp_spec_path and os.path.exists(temp_spec_path):
            os.remove(temp_spec_path)
            logger.info(f"Cleaned up temporary file: {temp_spec_path}")


def _new_upload_path(file):
    """Return a unique path in the upload folder for an uploaded file."""
    filename = secure_filename(file.filename)
    # A unique prefix avoids collisions between concurrent uploads.
    unique_filename = f"{uuid.uuid4()}_{filename}"
    return os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)


def _discard_upload(path):
    """Delete a temporary upload if it was created."""
    if path and os.path.exists(path):
        os.remove(path)


@app.route('/list-yapi-categories', methods=['POST'])
@login_required
def list_yapi_categories_endpoint():
    """Return the categories found in an uploaded YAPI specification.

    Expects the file in the ``api_spec_file`` multipart part and responds with
    a JSON list of ``{"name", "description"}`` objects.
    """
    if 'api_spec_file' not in request.files:
        return jsonify({"error": "api_spec_file part is missing"}), 400

    file = request.files['api_spec_file']
    if file.filename == '':
        return jsonify({"error": "No file selected"}), 400

    temp_spec_path = None
    try:
        temp_spec_path = _new_upload_path(file)
        file.save(temp_spec_path)

        parsed_yapi = InputParser().parse_yapi_spec(temp_spec_path)
        if not parsed_yapi or not getattr(parsed_yapi, 'categories', None):
            return jsonify({"error": "Failed to parse YAPI categories or no categories found"}), 500

        categories_list = [
            {
                "name": cat.get('name', '未命名'),
                # Prefer 'desc', then 'description'; empty values fall through.
                "description": cat.get('desc') or cat.get('description') or '无描述'
            }
            for cat in parsed_yapi.categories
        ]
        return jsonify(categories_list), 200
    except Exception as e:
        logger.error(f"Error fetching YAPI categories: {traceback.format_exc()}")
        return jsonify({"error": str(e)}), 500
    finally:
        _discard_upload(temp_spec_path)


@app.route('/list-swagger-tags', methods=['POST'])
@login_required
def list_swagger_tags_endpoint():
    """Return the tags found in an uploaded Swagger specification.

    Expects the file in the ``api_spec_file`` multipart part and responds with
    a JSON list of ``{"name", "description"}`` objects.
    """
    if 'api_spec_file' not in request.files:
        return jsonify({"error": "api_spec_file part is missing"}), 400

    file = request.files['api_spec_file']
    if file.filename == '':
        return jsonify({"error": "No file selected"}), 400

    temp_spec_path = None
    try:
        temp_spec_path = _new_upload_path(file)
        file.save(temp_spec_path)

        parsed_swagger = InputParser().parse_swagger_spec(temp_spec_path)
        if not parsed_swagger or not getattr(parsed_swagger, 'tags', None):
            return jsonify({"error": "Failed to parse Swagger tags or no tags found"}), 500

        tags_list = [
            {"name": tag.get('name', '未命名'), "description": tag.get('description', '无描述')}
            for tag in parsed_swagger.tags
        ]
        return jsonify(tags_list), 200
    except Exception as e:
        logger.error(f"Error fetching Swagger tags: {traceback.format_exc()}")
        return jsonify({"error": str(e)}), 500
    finally:
        _discard_upload(temp_spec_path)


# The ``path`` converter lets ``filepath`` contain slashes; without a converter
# the view would never receive its ``filepath`` argument.
@app.route('/download/<path:filepath>')
@login_required
def download_file(filepath):
    """Serve a file from the reports directory as an attachment.

    This is a simplified download endpoint. For production, consider more
    robust checks on ``filepath`` and a centrally configured reports directory.
    Requests with ``..`` in the path or an absolute path are answered with 404.
    """
    from flask import abort

    reports_base_dir = os.path.join(APP_ROOT, 'test_reports')
    # Basic check against directory traversal.
    if '..' in filepath or os.path.isabs(filepath):
        abort(404)

    logger.info(f"Attempting to serve file: {filepath} from directory: {reports_base_dir}")
    return send_from_directory(reports_base_dir, filepath, as_attachment=True)


if __name__ == '__main__':
    # NOTE: in production run the app behind a WSGI server such as Gunicorn
    # or uWSGI. The database can be initialised beforehand with
    # `flask --app flask_app init-db`.
    #
    # The Werkzeug debugger allows arbitrary code execution, so it must not be
    # on by default while listening on all interfaces; opt in with FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true')
    app.run(debug=debug_enabled, host='0.0.0.0', port=5050)