"""Render batched TCP-montage EEG snapshots around ``.rec`` annotations.

The script walks an input tree for folders holding ``.edf``/``.rec`` pairs.
Each EDF is band-pass/notch filtered, resampled and re-referenced to the
22-channel TCP bipolar montage.  For every valid annotation line in the
matching ``.rec`` file, a time window around the annotation is plotted, split
into images of ``CHANNELS_PER_BATCH_IMAGE`` channels each.  A JSON summary is
written per folder and once globally.
"""

import argparse
import glob
import json
import math  # For ceil
import os
import sys
import warnings
from collections import Counter  # For counting annotations

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.transforms as transforms
import mne
import numpy as np
from tqdm import tqdm

# --- User configuration ---
ROOT_INPUT_DIR = "/data/shuiyuan/workplace/EEG/isip.piconepress.com/data/tuh_eeg/tuh_eeg_events/v2.0.1/edf/"
ROOT_OUTPUT_DIR = "/data1/gongwenxin/workspace/eeg/dataset/tuh/eeg_annotation_plots_filtered_5s_resample/"
OUTPUT_SUBDIR_NO_HIGHLIGHT_ALL = "no_highlight/all_channels_batched"  # The only image output sub-directory kept
FOLDER_JSON_FILENAME = "_folder_summary.json"
GLOBAL_JSON_FILENAME = "global_summary_metadata.json"

# --- Batched plotting ---
CHANNELS_PER_BATCH_IMAGE = 5  # Number of channels drawn in each "all channels" image

# --- Preprocessing parameters (values unchanged from the previously hard-coded ones) ---
BANDPASS_LOW_HZ = 0.1
BANDPASS_HIGH_HZ = 75.0
# NOTE(review): the default input path points at TUH data; if those recordings
# were made on 60 Hz mains, a 50 Hz notch does not remove the line noise.
# Confirm against the dataset documentation before changing this value.
NOTCH_FREQ_HZ = 50.0
RESAMPLE_HZ = 100
RESAMPLE_N_JOBS = 5

# --- Plot parameters ---
PLOT_BUFFER_SECONDS = 2
MIN_PLOT_DURATION = 1.0
FIG_WIDTH_INCHES = 12
FIG_HEIGHT_PER_CHANNEL = 2  # Kept fairly tall to reduce label overlap
DPI = 100
LINE_COLOR = 'black'
LINE_WIDTH = 0.5
ZERO_LINE_COLOR = 'gray'
ZERO_LINE_STYLE = '-'
ZERO_LINE_WIDTH = 0.5
VOLTAGE_LABEL_PRECISION = 0
Y_SCALING_MODE = 'fixed'
FIXED_Y_SCALE_UV = 150
Y_PADDING_FACTOR = 0.1

# --- TCP montage definition: channel i is anodes[i] minus cathodes[i] ---
anodes = ['EEG FP1-REF', 'EEG F7-REF', 'EEG T3-REF', 'EEG T5-REF',
          'EEG FP2-REF', 'EEG F8-REF', 'EEG T4-REF', 'EEG T6-REF',
          'EEG A1-REF', 'EEG T3-REF', 'EEG C3-REF', 'EEG CZ-REF',
          'EEG C4-REF', 'EEG T4-REF', 'EEG FP1-REF', 'EEG F3-REF',
          'EEG C3-REF', 'EEG P3-REF', 'EEG FP2-REF', 'EEG F4-REF',
          'EEG C4-REF', 'EEG P4-REF']
cathodes = ['EEG F7-REF', 'EEG T3-REF', 'EEG T5-REF', 'EEG O1-REF',
            'EEG F8-REF', 'EEG T4-REF', 'EEG T6-REF', 'EEG O2-REF',
            'EEG T3-REF', 'EEG C3-REF', 'EEG CZ-REF', 'EEG C4-REF',
            'EEG T4-REF', 'EEG A2-REF', 'EEG F3-REF', 'EEG C3-REF',
            'EEG P3-REF', 'EEG O1-REF', 'EEG F4-REF', 'EEG C4-REF',
            'EEG P4-REF', 'EEG O2-REF']
ch_names_tcp = ['FP1-F7', 'F7-T3', 'T3-T5', 'T5-O1',
                'FP2-F8', 'F8-T4', 'T4-T6', 'T6-O2',
                'A1-T3', 'T3-C3', 'C3-CZ', 'CZ-C4',
                'C4-T4', 'T4-A2', 'FP1-F3', 'F3-C3',
                'C3-P3', 'P3-O1', 'FP2-F4', 'F4-C4',
                'C4-P4', 'P4-O2']

# Mapping from REC label codes to their short descriptions
rec_label_map = {
    1: 'spsw',
    2: 'gped',
    3: 'pled',
    4: 'eyem',
    5: 'artf',
    6: 'bckg',
}
all_possible_labels = list(rec_label_map.values())


def _standardize_channel_name(name):
    """Drop the 'EEG ' prefix, the '-REF' suffix and dots from *name*, upper-cased."""
    return name.replace('EEG ', '').replace('-REF', '').replace('.', '').upper()


# Loop-invariant derived values, computed once instead of once per EDF file.
_ANODES_STD = [_standardize_channel_name(ch) for ch in anodes]
_CATHODES_STD = [_standardize_channel_name(ch) for ch in cathodes]
_REQUIRED_INPUT_CHANNELS = set(_ANODES_STD + _CATHODES_STD)

# Keyword arguments shared by every call to the plotting function.
_COMMON_PLOT_ARGS = {
    "y_scaling_mode": Y_SCALING_MODE,
    "fixed_y_scale_uv": FIXED_Y_SCALE_UV,
    "y_padding_factor": Y_PADDING_FACTOR,
    "fig_width": FIG_WIDTH_INCHES,
    "dpi": DPI,
    "line_color": LINE_COLOR,
    "line_width": LINE_WIDTH,
    "zero_line_color": ZERO_LINE_COLOR,
    "zero_line_style": ZERO_LINE_STYLE,
    "zero_line_width": ZERO_LINE_WIDTH,
    "voltage_label_precision": VOLTAGE_LABEL_PRECISION,
}


def _resolve_display_scale_uv(data_uv, y_scaling_mode, fixed_y_scale_uv, y_padding_factor):
    """Return the half-height (in µV) of each channel's band in the stacked plot."""
    scale_uv = abs(fixed_y_scale_uv)
    if y_scaling_mode in ('global_auto', 'individual_auto'):
        if y_scaling_mode == 'individual_auto':
            warnings.warn("'individual_auto' scaling is not suitable for stacked plots. Using 'global_auto' behavior.")
        if data_uv.size > 0:
            global_max_abs = np.max(np.abs(data_uv))
            # A flat (or NaN) signal keeps the fixed scale instead of
            # collapsing the band to zero height.
            if global_max_abs > 1e-6:
                scale_uv = global_max_abs * (1 + y_padding_factor)
    if scale_uv <= 0:
        warnings.warn(f"计算出的 channel_display_scale_uv ({scale_uv:.2f}) 无效,将使用默认值 100uV。")
        scale_uv = 100.0
    return scale_uv


def plot_eeg_matplotlib_flexible_scale(
    data, times, ch_names, title, save_path,
    y_scaling_mode='fixed', fixed_y_scale_uv=100, y_padding_factor=0.1,
    fig_width=16, fig_height_per_channel=0.6, dpi=100,
    line_color='black', line_width=0.5,
    zero_line_color='dimgray', zero_line_style='-', zero_line_width=0.6,
    guide_line_color='lightgray', guide_line_style='-', guide_line_width=0.5,
    show_channel_voltage_labels=True, channel_voltage_label_color='gray',
    voltage_label_precision=0,
    x_tick_interval=0.1
):
    """Plot stacked EEG channels on a single Axes and save the figure.

    Each channel is drawn with a vertical offset.  For every channel a
    baseline (labelled "0 μV") and guide lines at ±1/3, ±2/3 and ±1 of the
    display scale are drawn; voltage labels for the guide lines are optional.

    Warning: with ``show_channel_voltage_labels=True`` the labels may overlap
    when channels are packed densely.

    Args:
        data: 2-D array indexed ``[channel, sample]``; multiplied by 1e6
            before plotting (i.e. treated as volts and shown in µV).
        times: 1-D sequence of sample times in seconds, one per column.
        ch_names: Channel labels, one per row of ``data``.
        title: Figure title passed to ``fig.suptitle``.
        save_path: Output image path; its parent directory is created.
        y_scaling_mode: ``'fixed'`` uses ``fixed_y_scale_uv``;
            ``'global_auto'`` (and ``'individual_auto'``, which falls back to
            it with a warning) uses the data's max absolute value plus padding.
        fixed_y_scale_uv: Half-height of a channel band in µV for fixed mode.
        y_padding_factor: Fractional head-room added in auto mode.
        fig_width: Figure width in inches.
        fig_height_per_channel: Figure height contributed by each channel.
        dpi: Figure resolution.
        line_color: Signal trace colour.
        line_width: Signal trace width.
        zero_line_color: Colour of each channel's offset baseline (0 µV line).
        zero_line_style: Line style of each channel's offset baseline.
        zero_line_width: Line width of each channel's offset baseline.
        guide_line_color: Colour of the horizontal guide lines.
        guide_line_style: Line style of the horizontal guide lines.
        guide_line_width: Line width of the horizontal guide lines.
        show_channel_voltage_labels: Whether to label every guide line with
            its signed voltage.
        channel_voltage_label_color: Colour of those voltage labels.
        voltage_label_precision: Decimal places used in the voltage labels.
        x_tick_interval: Minor x-grid spacing in seconds; the major grid is
            ten times coarser.  ``None`` or a non-positive value disables the
            grid.

    Returns:
        bool: True if the image was written, False otherwise.
    """
    n_channels = data.shape[0]
    if n_channels == 0:
        warnings.warn("数据中没有通道,无法绘图。")
        return False
    if len(times) == 0:
        warnings.warn("No time samples available; nothing to plot.")
        return False
    if len(ch_names) != n_channels:
        warnings.warn(
            f"Got {len(ch_names)} channel names for {n_channels} channels; nothing plotted."
        )
        return False

    plot_start_time = times[0]
    plot_end_time = times[-1]
    actual_plot_duration = plot_end_time - plot_start_time
    data_uv = data * 1e6

    # --- Y scaling and per-channel offsets ---
    channel_display_scale_uv = _resolve_display_scale_uv(
        data_uv, y_scaling_mode, fixed_y_scale_uv, y_padding_factor
    )
    fixed_gap_uv = 50.0
    offset_step = 2 * channel_display_scale_uv + fixed_gap_uv
    # The first channel gets the largest offset so it ends up at the top.
    y_offsets = [(n_channels - 1 - i) * offset_step for i in range(n_channels)]
    guide_levels_uv = [  # Levels relative to each channel's baseline
        channel_display_scale_uv,
        channel_display_scale_uv * 2/3,
        channel_display_scale_uv * 1/3,
        -channel_display_scale_uv * 1/3,
        -channel_display_scale_uv * 2/3,
        -channel_display_scale_uv,
    ]
    label_format = f"{{:+.{voltage_label_precision}f}} μV"  # Always signed

    total_fig_height = max(n_channels * fig_height_per_channel + 1.5, 4.0)
    fig, ax = plt.subplots(1, 1, figsize=(fig_width, total_fig_height), dpi=dpi)
    try:
        # x in axes coordinates (just right of the plot), y in data coordinates.
        trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)

        for channel_data, offset in zip(data_uv, y_offsets):
            # 1. Signal
            ax.plot(times, channel_data + offset, color=line_color,
                    linewidth=line_width, zorder=3)
            # 2. Baseline (0 µV)
            ax.axhline(offset, color=zero_line_color, linestyle=zero_line_style,
                       linewidth=zero_line_width, zorder=2, alpha=0.9)
            ax.text(1.001, offset, " 0 μV", transform=trans, va='center',
                    ha='left', fontsize=6, color=zero_line_color, zorder=2)
            # 3. Guide lines, each optionally labelled with its voltage
            for level_uv in guide_levels_uv:
                ax.axhline(offset + level_uv, color=guide_line_color,
                           linestyle=guide_line_style,
                           linewidth=guide_line_width, alpha=0.7, zorder=1)
                if show_channel_voltage_labels:
                    ax.text(1.001, offset + level_uv,
                            label_format.format(level_uv), transform=trans,
                            va='center', ha='left', fontsize=6,
                            color=channel_voltage_label_color, zorder=2)

        # --- X axis and grid ---
        ax.set_xlim(plot_start_time, plot_end_time)
        ax.set_xlabel("Time (s)", fontsize=9)
        ax.tick_params(axis='x', labelsize=7)
        if actual_plot_duration > 0 and x_tick_interval is not None and x_tick_interval > 0:
            # Major grid every 10 * x_tick_interval, minor grid every x_tick_interval.
            ax.xaxis.set_major_locator(ticker.MultipleLocator(base=x_tick_interval * 10))
            ax.xaxis.grid(True, which='major', linestyle='-', linewidth=0.5,
                          color='darkgrey', alpha=0.7, zorder=0)
            ax.xaxis.set_minor_locator(ticker.MultipleLocator(base=x_tick_interval))
            ax.xaxis.grid(True, which='minor', linestyle='-', linewidth=0.5,
                          color='lightgray', alpha=0.7, zorder=0)
        else:
            ax.xaxis.grid(False)
        ax.yaxis.grid(False)

        # --- Y axis: channel names as tick labels at each baseline ---
        for side in ('top', 'right', 'left'):
            ax.spines[side].set_visible(False)
        ax.spines['bottom'].set_visible(True)
        ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=True)
        ax.set_yticks(y_offsets)
        ax.set_yticklabels(ch_names, fontsize=8, va='center')
        y_min_limit = y_offsets[-1] + min(guide_levels_uv) - channel_display_scale_uv * 0.2
        y_max_limit = y_offsets[0] + max(guide_levels_uv) + channel_display_scale_uv * 0.2
        if y_max_limit <= y_min_limit:
            y_max_limit = y_min_limit + offset_step
        ax.set_ylim(y_min_limit, y_max_limit)

        # --- Title and layout ---
        fig.suptitle(title, fontsize=12)
        fig.tight_layout()

        # --- Save ---
        try:
            out_dir = os.path.dirname(save_path)
            if out_dir:  # A bare file name has no directory to create
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(save_path)
        except Exception as e:
            print(f" 错误: 保存图像到 '{save_path}' 时失败: {e}")
            return False
        return True
    finally:
        # Always release the figure: thousands are created in one run, and an
        # exception while drawing must not leave one registered with pyplot.
        plt.close(fig)


def _load_tcp_raw(edf_path):
    """Load one EDF and convert it to the filtered, resampled TCP montage.

    Returns:
        tuple: ``(raw_tcp, missing)``.  ``raw_tcp`` is None when required
        input channels are absent; ``missing`` then holds their names.
    """
    with mne.utils.use_log_level('WARNING'):
        # Read only the header first so files lacking required channels are
        # rejected before any signal data is loaded or filtered.
        raw = mne.io.read_raw_edf(edf_path, preload=False, verbose=False)
        tqdm.write(f"采样频率: {raw.info['sfreq']} Hz")
        tqdm.write(f"信息: {raw.info}")
        raw.rename_channels(_standardize_channel_name, verbose=False)

    missing_channels = _REQUIRED_INPUT_CHANNELS - set(raw.ch_names)
    if missing_channels:
        return None, missing_channels

    with mne.utils.use_log_level('WARNING'):
        raw.load_data(verbose=False)
        raw.filter(l_freq=BANDPASS_LOW_HZ, h_freq=BANDPASS_HIGH_HZ)
        raw.notch_filter(NOTCH_FREQ_HZ)  # Notch filter for power line noise
        raw.resample(RESAMPLE_HZ, n_jobs=RESAMPLE_N_JOBS)

    with mne.utils.use_log_level('ERROR'):
        try:
            raw.set_montage(mne.channels.make_standard_montage('standard_1020'),
                            on_missing='warn', verbose=False)
        except Exception:
            # Deliberately best-effort: electrode positions are not used for
            # plotting, so a montage failure must not discard the file.
            pass
        raw_tcp = mne.set_bipolar_reference(
            raw, anode=_ANODES_STD, cathode=_CATHODES_STD,
            ch_name=ch_names_tcp, copy=True, verbose=False
        )
        # REC channel indices refer to positions in ch_names_tcp, so keep only
        # the TCP channels and force exactly that order.
        raw_tcp.reorder_channels(ch_names_tcp)
    return raw_tcp, set()


def _parse_rec_annotations(rec_path, tcp_ch_names):
    """Parse a REC file into a list of validated annotation dicts.

    Each usable line has four comma-separated fields: channel index, start
    second, stop second and label code.  Malformed or out-of-range lines are
    skipped silently.
    """
    annotations = []
    n_tcp_channels = len(tcp_ch_names)
    with open(rec_path, 'r', encoding='utf-8', errors='replace') as f_rec:
        for line_num, line in enumerate(f_rec, start=1):
            line = line.strip()
            parts = line.split(',')
            if not line or len(parts) != 4:
                continue
            try:
                ch_idx = int(parts[0].strip())
                start = float(parts[1].strip())
                stop = float(parts[2].strip())
                lbl_code = int(parts[3].strip())
            except ValueError:
                continue  # Ignore lines with non-numeric values where expected
            # float() accepts 'nan'/'inf'; such values cannot define a window.
            if not (math.isfinite(start) and math.isfinite(stop)):
                continue
            duration = stop - start
            if (not 0 <= ch_idx < n_tcp_channels
                    or duration <= 0
                    or lbl_code not in rec_label_map):
                continue
            annotations.append({
                "ch_index": ch_idx,  # Index within the final TCP montage
                "ch_name": tcp_ch_names[ch_idx],  # Actual name in raw_tcp
                "start": start,
                "stop": stop,
                "duration": duration,
                "label_code": lbl_code,
                "label": rec_label_map[lbl_code],
                "rec_line": line_num,
            })
    return annotations


def _compute_plot_window(start_time, duration, recording_end):
    """Return ``(plot_start, plot_end)`` centred on the annotation, clipped to the recording."""
    plot_duration = max(duration + 2 * PLOT_BUFFER_SECONDS, MIN_PLOT_DURATION)
    plot_center = start_time + duration / 2.0
    plot_start = max(0, plot_center - plot_duration / 2.0)
    plot_end = min(recording_end, plot_center + plot_duration / 2.0)
    return plot_start, plot_end


def _render_trigger_annotation(raw_tcp, recording_end, trigger_annot,
                               edf_path, rec_path, image_dir):
    """Plot the window around one annotation as batched images.

    Returns:
        The metadata dict for this annotation, or None when no image was
        written (window too short, crop failure, or every save failed).
    """
    start_time = trigger_annot['start']
    plot_start, plot_end = _compute_plot_window(
        start_time, trigger_annot['duration'], recording_end
    )
    actual_plot_duration = plot_end - plot_start
    if actual_plot_duration < 0.1:  # Skip very short or invalid intervals
        tqdm.write(f" Skipping annotation at {start_time:.2f}s: plot duration too short ({actual_plot_duration:.3f}s).")
        return None

    try:
        with mne.utils.use_log_level('ERROR'):
            # Crop a copy so the processed recording stays intact for the
            # next annotation.
            cropped_raw = raw_tcp.copy().crop(tmin=plot_start, tmax=plot_end,
                                              include_tmax=True, verbose=False)
            data_cropped, times_cropped = cropped_raw.get_data(return_times=True)
    except Exception as crop_err:
        tqdm.write(f" Error cropping data for annotation at {start_time:.2f}s: {crop_err}. Skipping annotation.")
        return None
    if data_cropped.size == 0 or times_cropped.size == 0:
        tqdm.write(f" Skipping annotation at {start_time:.2f}s: Cropped data is empty.")
        return None

    # Base image name (without batch information)
    base_filename = os.path.splitext(os.path.basename(edf_path))[0]
    safe_time_str = f"{start_time:.2f}".replace('.', 'p')
    base_img_name_trigger = (
        f"{base_filename}_ch{trigger_annot['ch_index']}_{trigger_annot['ch_name']}"
        f"_{safe_time_str}s_{trigger_annot['label']}"
    )

    n_channels_total = data_cropped.shape[0]
    num_batches = math.ceil(n_channels_total / CHANNELS_PER_BATCH_IMAGE)
    cropped_ch_names = cropped_raw.ch_names
    image_paths = []
    for i_batch in range(num_batches):
        start_idx = i_batch * CHANNELS_PER_BATCH_IMAGE
        end_idx = min(start_idx + CHANNELS_PER_BATCH_IMAGE, n_channels_total)
        data_batch = data_cropped[start_idx:end_idx, :]
        if data_batch.size == 0:
            continue  # Skip empty batch
        batch_suffix = f"_batch{i_batch+1}of{num_batches}" if num_batches > 1 else ""
        save_path = os.path.abspath(os.path.join(
            image_dir, f"{base_img_name_trigger}_noHL_ALL{batch_suffix}.png"
        ))
        saved = plot_eeg_matplotlib_flexible_scale(
            data=data_batch,
            times=times_cropped,
            ch_names=cropped_ch_names[start_idx:end_idx],
            title="",
            save_path=save_path,
            fig_height_per_channel=FIG_HEIGHT_PER_CHANNEL,
            **_COMMON_PLOT_ARGS
        )
        if saved:
            image_paths.append(save_path)

    if not image_paths:
        return None
    return {
        "original_edf_abs": os.path.abspath(edf_path),
        "original_rec_abs": os.path.abspath(rec_path),
        "triggering_rec_line": trigger_annot['rec_line'],
        "triggering_annotation": trigger_annot,
        "plot_window_start_sec": plot_start,
        "plot_window_end_sec": plot_end,
        "actual_plot_duration_sec": actual_plot_duration,
        # One absolute path per batch image written for this annotation
        "image_paths_no_highlight_all_abs": image_paths,
    }


# --- Single-folder processing ---
def process_single_folder(input_folder, output_folder_base, args):
    """Process the EDF/REC pairs of one folder and write its images and summary.

    "All channels" images are generated in batches of
    ``CHANNELS_PER_BATCH_IMAGE`` channels.

    Args:
        input_folder: Folder containing ``.edf`` files and their ``.rec`` files.
        output_folder_base: Folder receiving the image sub-directory and the
            per-folder JSON summary.
        args: Parsed CLI namespace; only ``disable_no_hl_all`` is read.

    Returns:
        tuple: ``(metadata_list, label_counter, files_processed, images_generated)``.
    """
    # normpath drops a trailing separator, which would make basename empty.
    folder_name = os.path.basename(os.path.normpath(input_folder))
    output_dirs = {
        "no_hl_all": os.path.join(output_folder_base, OUTPUT_SUBDIR_NO_HIGHLIGHT_ALL),
    }
    images_enabled = not args.disable_no_hl_all
    if images_enabled:
        os.makedirs(output_dirs["no_hl_all"], exist_ok=True)

    folder_metadata_list = []
    folder_stats = Counter({label: 0 for label in all_possible_labels})
    processed_files_count = 0
    generated_images_count = 0  # Total number of image files

    # Match the extension case-insensitively, like the directory scan does.
    edf_files = sorted(
        path
        for path in glob.glob(os.path.join(glob.escape(input_folder), '*'))
        if path.lower().endswith('.edf')
    )
    if not edf_files:
        return [], folder_stats, 0, 0

    for edf_path in tqdm(edf_files, desc=f"Processing {folder_name}", unit="file", leave=False):
        base_filename = os.path.splitext(os.path.basename(edf_path))[0]
        rec_path = os.path.join(input_folder, base_filename + '.rec')
        if not os.path.exists(rec_path):
            tqdm.write(f" Warning: REC file not found for '{base_filename}.edf' in {folder_name}. Skipping pair.")
            continue

        # --- 1. Load and preprocess the EEG data ---
        try:
            raw_tcp, missing_channels = _load_tcp_raw(edf_path)
        except Exception as e:
            tqdm.write(f" Error: Failed to load/preprocess EDF '{base_filename}' in {folder_name}: {e}. Skipping.")
            continue
        if raw_tcp is None:
            tqdm.write(f" Error: EDF {base_filename} missing required TCP channels: {missing_channels}. Skipping.")
            continue

        # --- 2. Parse the whole REC file ---
        try:
            all_annotations_in_file = _parse_rec_annotations(rec_path, raw_tcp.ch_names)
        except Exception as e:
            tqdm.write(f" Error: Failed to read/parse REC '{base_filename}.rec' in {folder_name}: {e}. Skipping file pair.")
            continue
        if not all_annotations_in_file:
            processed_files_count += 1  # Still counts as processed without annotations
            continue
        folder_stats.update(annot["label"] for annot in all_annotations_in_file)

        # --- 3. Every annotation triggers one (batched) set of images ---
        if images_enabled:
            recording_end = float(raw_tcp.times[-1])
            for trigger_annot in all_annotations_in_file:
                metadata_entry = _render_trigger_annotation(
                    raw_tcp, recording_end, trigger_annot,
                    edf_path, rec_path, output_dirs["no_hl_all"]
                )
                if metadata_entry is not None:
                    folder_metadata_list.append(metadata_entry)
                    generated_images_count += len(
                        metadata_entry["image_paths_no_highlight_all_abs"]
                    )
        processed_files_count += 1

    # --- Save this folder's summary JSON ---
    folder_summary_path = os.path.join(output_folder_base, FOLDER_JSON_FILENAME)
    folder_summary_data = {
        "input_folder": os.path.abspath(input_folder),
        "output_folder": os.path.abspath(output_folder_base),
        "channels_per_batched_image": CHANNELS_PER_BATCH_IMAGE,
        "files_processed": processed_files_count,
        "images_generated": generated_images_count,  # Individual image files
        "annotation_counts": dict(folder_stats),
        "image_metadata": folder_metadata_list,  # One entry per trigger annotation
    }
    try:
        if processed_files_count > 0 or generated_images_count > 0:
            os.makedirs(output_folder_base, exist_ok=True)
            with open(folder_summary_path, 'w', encoding='utf-8') as f_json:
                # default=str stringifies any value json cannot serialize itself
                json.dump(folder_summary_data, f_json, indent=4, default=str)
    except Exception as e:
        tqdm.write(f" Error: Failed to save folder summary JSON '{folder_summary_path}': {e}")

    return folder_metadata_list, folder_stats, processed_files_count, generated_images_count


def _folder_has_edf_rec_pair(subdir, files):
    """Return True if at least one ``.edf`` in *files* has a sibling ``.rec`` file."""
    return any(
        os.path.exists(os.path.join(subdir, fname[:-4] + '.rec'))
        for fname in files
        if fname.lower().endswith('.edf')
    )


# --- Recursive driver ---
def process_directory_recursive(root_input_dir, root_output_dir, args):
    """Process every data folder below *root_input_dir* and write the global summary.

    A data folder is one containing at least one ``.edf`` with a matching
    ``.rec``.  The output tree mirrors the input tree below *root_output_dir*.
    ``args`` is passed down to control image generation.
    """
    print(f"Scanning for data folders in: {root_input_dir}...")
    global_metadata_list = []
    global_stats = Counter({label: 0 for label in all_possible_labels})
    total_files_processed_global = 0
    total_images_generated_global = 0  # Counts individual image files

    data_folders_to_process = []
    for subdir, dirs, files in os.walk(root_input_dir):
        dirs.sort()  # Deterministic traversal, hence reproducible summaries
        if not any(fname.lower().endswith('.edf') for fname in files):
            continue
        if _folder_has_edf_rec_pair(subdir, files):
            data_folders_to_process.append(subdir)
        else:
            print(f" Skipping folder (no .rec found): {subdir}")

    if not data_folders_to_process:
        print("No folders containing corresponding .edf and .rec files found.")
        return

    print(f"Found {len(data_folders_to_process)} data folder(s) with EDF/REC pairs to process.")
    print(f"Outputting batched 'all channels' images with {CHANNELS_PER_BATCH_IMAGE} channels per image.")
    print("Image generation settings:")
    print(f" - No Highlight All (Batched): {'Enabled' if not args.disable_no_hl_all else 'Disabled'}")

    processed_folders_count = 0
    for subdir in tqdm(data_folders_to_process, desc="Overall Progress", unit="folder"):
        processed_folders_count += 1
        # Mirror the input layout below the output root.
        relative_path = os.path.relpath(subdir, root_input_dir)
        current_output_dir = os.path.join(root_output_dir, relative_path)
        (folder_metadata, folder_stats,
         folder_files_processed, folder_images_generated) = process_single_folder(
            subdir, current_output_dir, args
        )
        global_metadata_list.extend(folder_metadata)
        global_stats.update(folder_stats)
        total_files_processed_global += folder_files_processed
        total_images_generated_global += folder_images_generated

    # --- Save the global JSON metadata and statistics ---
    if global_metadata_list or processed_folders_count > 0:
        global_json_path = os.path.join(root_output_dir, GLOBAL_JSON_FILENAME)
        global_summary_data = {
            "root_input_directory": os.path.abspath(root_input_dir),
            "root_output_directory": os.path.abspath(root_output_dir),
            "channels_per_batched_image": CHANNELS_PER_BATCH_IMAGE,
            "total_folders_processed": processed_folders_count,
            "total_edf_rec_pairs_processed": total_files_processed_global,
            "total_images_generated": total_images_generated_global,
        }
        # Per-label counts use a "count_" prefix; every known label is present.
        for label, count in global_stats.items():
            global_summary_data[f"count_{label}"] = count
        for label in all_possible_labels:
            global_summary_data.setdefault(f"count_{label}", 0)
        # The full metadata list goes last so the scalar totals stay on top.
        global_summary_data["all_image_metadata"] = global_metadata_list

        try:
            os.makedirs(root_output_dir, exist_ok=True)
            with open(global_json_path, 'w', encoding='utf-8') as f_json:
                json.dump(global_summary_data, f_json, indent=4, default=str)
        except Exception as e:
            print(f"\nError: Failed to save global JSON summary '{global_json_path}': {e}")
        else:
            print("\n--- === Global Processing Complete === ---")
            print(f"Processed {processed_folders_count} data folder(s).")
            print(f"Processed a total of {total_files_processed_global} EDF/REC file pairs.")
            print(f"Generated a total of {total_images_generated_global} individual image files (including batches).")
            print(f"Global summary and metadata saved to: {global_json_path}")
            print("Global annotation counts:")
            for label in sorted(all_possible_labels):
                print(f" - {label}: {global_stats[label]}")
    else:
        print("\n--- === Global Processing Complete === ---")
        print("No data folders were found or processed.")


def main():
    """Parse the command line and run the recursive processing."""
    parser = argparse.ArgumentParser(description="Generate cropped EEG images from EDF/REC pairs in a directory structure, optionally batching multi-channel plots.")
    parser.add_argument('-i', '--input-dir', type=str, default=ROOT_INPUT_DIR,
                        help=f"Root directory containing the input data folders (e.g., eval/, train/). Default: {ROOT_INPUT_DIR}")
    parser.add_argument('-o', '--output-dir', type=str, default=ROOT_OUTPUT_DIR,
                        help=f"Root directory where output images and JSON summaries will be saved. Default: {ROOT_OUTPUT_DIR}")
    # Flag to disable the (only) image type; generation is enabled by default.
    parser.add_argument('--disable-no-hl-all', action='store_true', default=False,
                        help="Disable generation of 'no highlight, all channels' (batched) images.")
    args = parser.parse_args()

    root_input_dir = args.input_dir
    root_output_dir = args.output_dir

    if not os.path.isdir(root_input_dir):
        print(f"错误:输入根目录 '{root_input_dir}' 不存在或不是一个目录。请检查 --input-dir 参数或脚本中的 ROOT_INPUT_DIR。")
        sys.exit(1)

    # Ensure the base output directory exists before starting.
    os.makedirs(root_output_dir, exist_ok=True)
    process_directory_recursive(root_input_dir, root_output_dir, args)


# --- Script entry point ---
if __name__ == "__main__":
    main()