compliance/1.py
2025-05-16 15:18:02 +08:00

589 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import mne
import os
import json
import glob
import matplotlib.pyplot as plt
import sys
import numpy as np
import math # For ceil
from collections import Counter # For counting annotations
from tqdm import tqdm # <<< Import tqdm
import argparse
import matplotlib.ticker as ticker
import matplotlib.transforms as transforms
import warnings
# --- User configuration ---
ROOT_INPUT_DIR = "/data/shuiyuan/workplace/EEG/isip.piconepress.com/data/tuh_eeg/tuh_eeg_events/v2.0.1/edf/"
ROOT_OUTPUT_DIR = "/data1/gongwenxin/workspace/eeg/dataset/tuh/eeg_annotation_plots_filtered_5s_resample/"
# The only image output sub-directory that is still produced.
OUTPUT_SUBDIR_NO_HIGHLIGHT_ALL = "no_highlight/all_channels_batched"
FOLDER_JSON_FILENAME = "_folder_summary.json"
GLOBAL_JSON_FILENAME = "global_summary_metadata.json"

# --- Batched plotting ---
# Number of channels drawn on each "all channels" image.
CHANNELS_PER_BATCH_IMAGE = 5

# --- Plot appearance ---
PLOT_BUFFER_SECONDS = 2
MIN_PLOT_DURATION = 1.0
FIG_WIDTH_INCHES = 12
# Kept fairly tall so that the per-channel labels overlap less.
FIG_HEIGHT_PER_CHANNEL = 2
DPI = 100
LINE_COLOR = 'black'
LINE_WIDTH = 0.5
ZERO_LINE_COLOR = 'gray'
ZERO_LINE_STYLE = '-'
ZERO_LINE_WIDTH = 0.5
VOLTAGE_LABEL_PRECISION = 0
Y_SCALING_MODE = 'fixed'
FIXED_Y_SCALE_UV = 150
Y_PADDING_FACTOR = 0.1

# --- TCP montage definition ---
# Each entry is (anode electrode, cathode electrode) of one bipolar channel.
# The three lists below are derived from this single table so that they
# cannot drift out of step with each other.
_TCP_ELECTRODE_PAIRS = (
    ('FP1', 'F7'), ('F7', 'T3'), ('T3', 'T5'), ('T5', 'O1'),
    ('FP2', 'F8'), ('F8', 'T4'), ('T4', 'T6'), ('T6', 'O2'),
    ('A1', 'T3'), ('T3', 'C3'), ('C3', 'CZ'), ('CZ', 'C4'),
    ('C4', 'T4'), ('T4', 'A2'),
    ('FP1', 'F3'), ('F3', 'C3'), ('C3', 'P3'), ('P3', 'O1'),
    ('FP2', 'F4'), ('F4', 'C4'), ('C4', 'P4'), ('P4', 'O2'),
)
anodes = [f'EEG {anode}-REF' for anode, _ in _TCP_ELECTRODE_PAIRS]
cathodes = [f'EEG {cathode}-REF' for _, cathode in _TCP_ELECTRODE_PAIRS]
ch_names_tcp = [f'{anode}-{cathode}' for anode, cathode in _TCP_ELECTRODE_PAIRS]

# Mapping from the integer label code found in .rec files to its description.
rec_label_map = dict(
    enumerate(('spsw', 'gped', 'pled', 'eyem', 'artf', 'bckg'), start=1)
)
all_possible_labels = list(rec_label_map.values())
def plot_eeg_matplotlib_flexible_scale(
    data, times, ch_names, title, save_path,
    y_scaling_mode='fixed', fixed_y_scale_uv=100, y_padding_factor=0.1,
    fig_width=16, fig_height_per_channel=0.6, dpi=100,
    line_color='black', line_width=0.5,
    zero_line_color='dimgray', zero_line_style='-', zero_line_width=0.6,
    guide_line_color='lightgray', guide_line_style='-', guide_line_width=0.5,
    show_channel_voltage_labels=True,
    channel_voltage_label_color='gray',
    voltage_label_precision=0,
    x_tick_interval=0.1
):
    """Draw all channels stacked on a single Axes and save the figure.

    Every channel gets its own vertical offset, a baseline labelled
    "0 μV", six horizontal guide lines at +/- 1/3, 2/3 and 1 times the
    display scale and, optionally, a voltage label next to each guide line.
    Labels may overlap when many channels are packed into a short figure.

    Args:
        data: 2-D array indexed as (channel, sample). Values are multiplied
            by 1e6 before plotting, i.e. they are assumed to be in volts.
        times: 1-D sequence of sample times in seconds, one per sample.
        ch_names: Channel names used as y tick labels, one per row of data.
        title: Figure title.
        save_path: Destination image path; its parent directory is created
            when missing.
        y_scaling_mode: 'fixed' uses fixed_y_scale_uv. 'global_auto' derives
            the scale from the largest absolute sample of all channels.
            'individual_auto' is not supported for stacked plots and is
            treated as 'global_auto' after a warning.
        fixed_y_scale_uv: Half-height of one channel band in microvolts.
        y_padding_factor: Relative head-room added in 'global_auto' mode.
        fig_width: Figure width in inches.
        fig_height_per_channel: Figure height contributed by each channel,
            in inches (the total is never smaller than 4 inches).
        dpi: Figure resolution.
        line_color, line_width: Style of the signal traces.
        zero_line_color, zero_line_style, zero_line_width: Style of each
            channel's baseline (0 μV line).
        guide_line_color, guide_line_style, guide_line_width: Style of the
            horizontal guide lines.
        show_channel_voltage_labels: Whether to label the guide lines with
            their signed voltage.
        channel_voltage_label_color: Colour of those voltage labels.
        voltage_label_precision: Number of decimals in the voltage labels.
        x_tick_interval: Minor x grid spacing in seconds; major grid lines
            are placed every ten minor intervals. None or a non-positive
            value disables the x grid.

    Returns:
        bool: True when the image was written, False when there was nothing
        to plot or saving failed.
    """
    n_channels = data.shape[0]
    if n_channels == 0:
        warnings.warn("数据中没有通道,无法绘图。")
        return False
    if len(times) == 0:
        # Without samples there is no time axis to draw; report failure the
        # same way as for missing channels instead of raising IndexError.
        warnings.warn("times is empty; nothing to plot.")
        return False

    plot_start_time = times[0]
    plot_end_time = times[-1]
    actual_plot_duration = plot_end_time - plot_start_time
    data_uv = data * 1e6

    # --- Y scaling and per-channel offsets ---
    channel_display_scale_uv = abs(fixed_y_scale_uv)
    if y_scaling_mode in ('global_auto', 'individual_auto'):
        if y_scaling_mode == 'individual_auto':
            warnings.warn("'individual_auto' scaling is not suitable for stacked plots. Using 'global_auto' behavior.")
        if data_uv.size > 0:
            global_max_abs = np.max(np.abs(data_uv))
            # A (near) flat recording keeps the fixed scale.
            if global_max_abs > 1e-6:
                channel_display_scale_uv = global_max_abs * (1 + y_padding_factor)
    if channel_display_scale_uv <= 0:
        warnings.warn(f"计算出的 channel_display_scale_uv ({channel_display_scale_uv:.2f}) 无效,将使用默认值 100uV。")
        channel_display_scale_uv = 100.0

    fixed_gap_uv = 50.0
    offset_step = 2 * channel_display_scale_uv + fixed_gap_uv
    # First channel on top, last channel at offset 0.
    y_offsets = [(n_channels - 1 - i) * offset_step for i in range(n_channels)]
    # Guide levels relative to each channel's baseline.
    guide_levels_uv = [
        channel_display_scale_uv, channel_display_scale_uv * 2/3, channel_display_scale_uv * 1/3,
        -channel_display_scale_uv * 1/3, -channel_display_scale_uv * 2/3, -channel_display_scale_uv
    ]
    label_format = f"{{:+.{voltage_label_precision}f}} μV"  # always signed

    total_fig_height = max(n_channels * fig_height_per_channel + 1.5, 4.0)
    fig, ax = plt.subplots(1, 1, figsize=(fig_width, total_fig_height), dpi=dpi)
    try:
        # x in axes fraction (just right of the plot area), y in data units.
        trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)

        # --- Per-channel drawing ---
        for channel_data, offset in zip(data_uv, y_offsets):
            # 1. Signal trace.
            ax.plot(times, channel_data + offset, color=line_color, linewidth=line_width, zorder=3)
            # 2. Baseline (0 μV) and its label.
            ax.axhline(offset, color=zero_line_color, linestyle=zero_line_style,
                       linewidth=zero_line_width, zorder=2, alpha=0.9)
            ax.text(1.001, offset, " 0 μV", transform=trans,
                    va='center', ha='left', fontsize=6, color=zero_line_color, zorder=2)
            # 3. Guide lines, each optionally labelled with its voltage.
            for level_uv in guide_levels_uv:
                ax.axhline(offset + level_uv, color=guide_line_color, linestyle=guide_line_style,
                           linewidth=guide_line_width, alpha=0.7, zorder=1)
                if show_channel_voltage_labels:
                    ax.text(1.001, offset + level_uv, label_format.format(level_uv),
                            transform=trans, va='center', ha='left', fontsize=6,
                            color=channel_voltage_label_color, zorder=2)

        # --- X axis and grid ---
        ax.set_xlim(plot_start_time, plot_end_time)
        ax.set_xlabel("Time (s)", fontsize=9)
        ax.tick_params(axis='x', labelsize=7)
        if actual_plot_duration > 0 and x_tick_interval is not None and x_tick_interval > 0:
            # Major grid every 10 minor intervals (1 s for the 0.1 s default).
            ax.xaxis.set_major_locator(ticker.MultipleLocator(base=x_tick_interval * 10))
            ax.xaxis.grid(True, which='major', linestyle='-', linewidth=0.5, color='darkgrey', alpha=0.7, zorder=0)
            # Minor grid every x_tick_interval seconds, in a lighter colour.
            ax.xaxis.set_minor_locator(ticker.MultipleLocator(base=x_tick_interval))
            ax.xaxis.grid(True, which='minor', linestyle='-', linewidth=0.5, color='lightgray', alpha=0.7, zorder=0)
            ax.yaxis.grid(False)
        else:
            ax.xaxis.grid(False)
            ax.yaxis.grid(False)

        # --- Y axis: channel names as tick labels at the baselines ---
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_visible(True)
        ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=True)
        ax.set_yticks(y_offsets)
        ax.set_yticklabels(ch_names, fontsize=8, va='center')
        y_min_limit = y_offsets[-1] + min(guide_levels_uv) - channel_display_scale_uv * 0.2
        y_max_limit = y_offsets[0] + max(guide_levels_uv) + channel_display_scale_uv * 0.2
        if y_max_limit <= y_min_limit:
            y_max_limit = y_min_limit + offset_step
        ax.set_ylim(y_min_limit, y_max_limit)

        # --- Title and layout ---
        fig.suptitle(title, fontsize=12)
        fig.tight_layout()

        # --- Save ---
        try:
            out_dir = os.path.dirname(save_path)
            # A bare file name has an empty dirname, which makedirs rejects.
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(save_path)
        except Exception as e:
            print(f" 错误: 保存图像到 '{save_path}' 时失败: {e}")
            return False
        return True
    finally:
        # Always release the figure, including when drawing itself raises;
        # this function is called once per image in a long batch run.
        plt.close(fig)
# --- 单个文件夹处理函数 ---
def _standardize_channel_name(name):
    """Strip the 'EEG ' prefix, '-REF' suffix and dots, then upper-case."""
    return name.replace('EEG ', '').replace('-REF', '').replace('.', '').upper()


def _load_tcp_raw(edf_path, anodes_std, cathodes_std):
    """Load one EDF file and convert it to the bipolar TCP montage.

    The recording is band-pass filtered (0.1-75 Hz), notch filtered at
    50 Hz, resampled to 100 Hz, its channels are renamed with
    _standardize_channel_name and finally re-referenced to ch_names_tcp.

    Returns:
        tuple: (raw_tcp, missing_channels). raw_tcp is None when the file
        lacks any electrode required by the montage; missing_channels is
        then the non-empty set of missing names.

    Raises:
        Exception: whatever MNE raises while reading or processing the
        file; the caller is expected to handle it.
    """
    with mne.utils.use_log_level('WARNING'):
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        # Routed through tqdm.write so the progress bars are not corrupted.
        tqdm.write(f"采样频率: {raw.info['sfreq']} Hz")
        tqdm.write(f"信息: {raw.info}")
        raw.filter(l_freq=0.1, h_freq=75.0)
        raw.notch_filter(50.0)  # power-line notch
        raw.resample(100, n_jobs=5)
    raw.rename_channels(_standardize_channel_name, verbose=False)

    missing_channels = set(anodes_std + cathodes_std) - set(raw.ch_names)
    if missing_channels:
        return None, missing_channels

    with mne.utils.use_log_level('ERROR'):
        try:
            raw.set_montage(mne.channels.make_standard_montage('standard_1020'),
                            on_missing='warn', verbose=False)
        except Exception:
            # Deliberately best-effort: nothing below reads electrode
            # positions, so a failed montage must not skip the file.
            pass
    raw_tcp = mne.set_bipolar_reference(raw, anode=anodes_std, cathode=cathodes_std,
                                        ch_name=ch_names_tcp, copy=True, verbose=False)
    raw_tcp.pick_channels(ch_names_tcp)  # keep only the TCP channels
    return raw_tcp, missing_channels


def _parse_rec_annotations(rec_path, ch_names):
    """Parse a .rec file into a list of annotation dicts.

    Each useful line has four comma-separated fields: channel index, start
    time, stop time and label code. Lines that are blank, have a different
    number of fields, contain non-numeric values, reference a channel index
    outside ch_names, have a non-positive duration or an unknown label code
    are silently skipped.
    """
    annotations = []
    with open(rec_path, 'r') as f_rec:
        for line_num, line in enumerate(f_rec, start=1):
            line = line.strip()
            parts = line.split(',')
            if not line or len(parts) != 4:
                continue
            try:
                ch_idx = int(parts[0].strip())
                start = float(parts[1].strip())
                stop = float(parts[2].strip())
                lbl_code = int(parts[3].strip())
            except ValueError:
                continue  # non-numeric field
            duration = stop - start
            if not (0 <= ch_idx < len(ch_names)) or duration <= 0 or lbl_code not in rec_label_map:
                continue
            annotations.append({
                "ch_index": ch_idx,            # index within the TCP montage
                "ch_name": ch_names[ch_idx],   # name of that channel
                "start": start, "stop": stop, "duration": duration,
                "label_code": lbl_code, "label": rec_label_map[lbl_code],
                "rec_line": line_num,          # 1-based line number
            })
    return annotations


def process_single_folder(input_folder, output_folder_base, args):
    """Process every EDF/REC pair of one folder into images and metadata.

    For each annotation of each .rec file a time window around the
    annotation is cropped from the TCP-montage recording and drawn as
    "all channels" images, split into batches of CHANNELS_PER_BATCH_IMAGE
    channels. A per-folder JSON summary is written to output_folder_base.

    Args:
        input_folder: Directory that contains the .edf and .rec files.
        output_folder_base: Directory receiving the images and the summary.
        args: Parsed CLI arguments; only disable_no_hl_all is read.

    Returns:
        tuple: (metadata entries, Counter of annotation labels,
        number of processed file pairs, number of image files written).
    """
    folder_name = os.path.basename(input_folder)
    output_dirs = {
        "no_hl_all": os.path.join(output_folder_base, OUTPUT_SUBDIR_NO_HIGHLIGHT_ALL),
    }
    if not args.disable_no_hl_all:
        os.makedirs(output_dirs["no_hl_all"], exist_ok=True)

    folder_metadata_list = []
    folder_stats = Counter({label: 0 for label in all_possible_labels})
    processed_files_count = 0
    generated_images_count = 0  # individual image files

    edf_files = sorted(glob.glob(os.path.join(input_folder, '*.edf')))
    if not edf_files:
        return [], folder_stats, 0, 0

    # Loop invariants, computed once per folder.
    anodes_std = [_standardize_channel_name(ch) for ch in anodes]
    cathodes_std = [_standardize_channel_name(ch) for ch in cathodes]
    common_plot_args = {
        "y_scaling_mode": Y_SCALING_MODE, "fixed_y_scale_uv": FIXED_Y_SCALE_UV,
        "y_padding_factor": Y_PADDING_FACTOR, "fig_width": FIG_WIDTH_INCHES, "dpi": DPI,
        "line_color": LINE_COLOR, "line_width": LINE_WIDTH,
        "zero_line_color": ZERO_LINE_COLOR, "zero_line_style": ZERO_LINE_STYLE,
        "zero_line_width": ZERO_LINE_WIDTH,
        "voltage_label_precision": VOLTAGE_LABEL_PRECISION,
    }

    for edf_path in tqdm(edf_files, desc=f"Processing {folder_name}", unit="file", leave=False):
        base_filename = os.path.splitext(os.path.basename(edf_path))[0]
        rec_path = os.path.join(input_folder, base_filename + '.rec')
        if not os.path.exists(rec_path):
            tqdm.write(f" Warning: REC file not found for '{base_filename}.edf' in {folder_name}. Skipping pair.")
            continue

        # --- 1. Load and preprocess the EEG data ---
        try:
            raw_tcp, missing_channels = _load_tcp_raw(edf_path, anodes_std, cathodes_std)
        except Exception as e:
            tqdm.write(f" Error: Failed to load/preprocess EDF '{base_filename}' in {folder_name}: {e}. Skipping.")
            continue
        if raw_tcp is None:
            tqdm.write(f" Error: EDF {base_filename} missing required TCP channels: {missing_channels}. Skipping.")
            continue

        # --- 2. Parse the whole REC file ---
        try:
            all_annotations_in_file = _parse_rec_annotations(rec_path, raw_tcp.ch_names)
        except Exception as e:
            tqdm.write(f" Error: Failed to read/parse REC '{base_filename}.rec' in {folder_name}: {e}. Skipping file pair.")
            continue
        if not all_annotations_in_file:
            # Still counts as processed, but there is nothing to draw.
            processed_files_count += 1
            continue
        folder_stats.update(annot["label"] for annot in all_annotations_in_file)

        recording_end = raw_tcp.times[-1]

        # --- 3. Every annotation triggers one set of images ---
        for trigger_annot in all_annotations_in_file:
            start_time = trigger_annot['start']
            duration = trigger_annot['duration']
            annotated_channel_name = trigger_annot['ch_name']
            label_description = trigger_annot['label']
            ch_index = trigger_annot['ch_index']
            rec_line = trigger_annot['rec_line']

            # Window centred on the annotation, padded on both sides and
            # clamped to the recording.
            plot_duration = max(duration + 2 * PLOT_BUFFER_SECONDS, MIN_PLOT_DURATION)
            plot_center = start_time + duration / 2.0
            plot_start = max(0, plot_center - plot_duration / 2.0)
            plot_end = min(recording_end, plot_center + plot_duration / 2.0)
            actual_plot_duration = plot_end - plot_start
            if actual_plot_duration < 0.1:
                tqdm.write(f" Skipping annotation at {start_time:.2f}s: plot duration too short ({actual_plot_duration:.3f}s).")
                continue

            try:
                with mne.utils.use_log_level('ERROR'):
                    # NOTE(review): copy() duplicates the whole recording for
                    # every annotation; kept because crop() defines the time
                    # axis that ends up in the images.
                    cropped_raw = raw_tcp.copy().crop(tmin=plot_start, tmax=plot_end,
                                                      include_tmax=True, verbose=False)
                    data_cropped, times_cropped = cropped_raw.get_data(return_times=True)
            except Exception as crop_err:
                tqdm.write(f" Error cropping data for annotation at {start_time:.2f}s: {crop_err}. Skipping annotation.")
                continue
            if data_cropped.size == 0 or times_cropped.size == 0:
                tqdm.write(f" Skipping annotation at {start_time:.2f}s: Cropped data is empty.")
                continue

            # Base image name, without the batch suffix.
            safe_time_str = f"{start_time:.2f}".replace('.', 'p')
            base_img_name_trigger = f"{base_filename}_ch{ch_index}_{annotated_channel_name}_{safe_time_str}s_{label_description}"

            img_success_paths_nh_all = []
            images_generated_this_trigger = 0

            n_channels_total = data_cropped.shape[0]
            num_batches = math.ceil(n_channels_total / CHANNELS_PER_BATCH_IMAGE)

            # No highlight - all channels (batched).
            if not args.disable_no_hl_all:
                for i_batch in range(num_batches):
                    start_idx = i_batch * CHANNELS_PER_BATCH_IMAGE
                    end_idx = min(start_idx + CHANNELS_PER_BATCH_IMAGE, n_channels_total)
                    data_batch = data_cropped[start_idx:end_idx, :]
                    ch_names_batch = cropped_raw.ch_names[start_idx:end_idx]
                    if data_batch.size == 0:
                        continue
                    batch_suffix = f"_batch{i_batch+1}of{num_batches}" if num_batches > 1 else ""
                    save_path_nh_all_batch = os.path.abspath(os.path.join(
                        output_dirs["no_hl_all"], f"{base_img_name_trigger}_noHL_ALL{batch_suffix}.png"))
                    success_nh_a = plot_eeg_matplotlib_flexible_scale(
                        data=data_batch, times=times_cropped, ch_names=ch_names_batch,
                        title="", save_path=save_path_nh_all_batch,
                        fig_height_per_channel=FIG_HEIGHT_PER_CHANNEL,
                        **common_plot_args
                    )
                    if success_nh_a:
                        images_generated_this_trigger += 1
                        img_success_paths_nh_all.append(save_path_nh_all_batch)

            # Metadata is recorded only when at least one image was written.
            if images_generated_this_trigger > 0:
                folder_metadata_list.append({
                    "original_edf_abs": os.path.abspath(edf_path),
                    "original_rec_abs": os.path.abspath(rec_path),
                    "triggering_rec_line": rec_line,
                    "triggering_annotation": trigger_annot,
                    "plot_window_start_sec": plot_start,
                    "plot_window_end_sec": plot_end,
                    "actual_plot_duration_sec": actual_plot_duration,
                    "image_paths_no_highlight_all_abs": img_success_paths_nh_all if img_success_paths_nh_all else None,
                })
                generated_images_count += images_generated_this_trigger

        processed_files_count += 1

    # --- Per-folder summary JSON ---
    folder_summary_path = os.path.join(output_folder_base, FOLDER_JSON_FILENAME)
    folder_summary_data = {
        "input_folder": os.path.abspath(input_folder),
        "output_folder": os.path.abspath(output_folder_base),
        "channels_per_batched_image": CHANNELS_PER_BATCH_IMAGE,
        "files_processed": processed_files_count,
        "images_generated": generated_images_count,
        "annotation_counts": dict(folder_stats),
        "image_metadata": folder_metadata_list,  # one entry per triggering annotation
    }
    try:
        if processed_files_count > 0 or generated_images_count > 0:
            os.makedirs(output_folder_base, exist_ok=True)
            with open(folder_summary_path, 'w') as f_json:
                # default=str keeps the dump from failing on stray
                # non-serialisable values.
                json.dump(folder_summary_data, f_json, indent=4, default=str)
    except Exception as e:
        tqdm.write(f" Error: Failed to save folder summary JSON '{folder_summary_path}': {e}")
    return folder_metadata_list, folder_stats, processed_files_count, generated_images_count
# --- 主递归处理函数 ---
def process_directory_recursive(root_input_dir, root_output_dir, args):
    """Process every data folder below root_input_dir and write a global summary.

    A data folder is a directory holding at least one .edf file for which a
    same-named .rec file exists. Each one is handed to process_single_folder
    with an output directory that mirrors its path relative to
    root_input_dir. The combined metadata and label counts are written to
    GLOBAL_JSON_FILENAME inside root_output_dir.

    Args:
        root_input_dir: Root of the tree that is scanned for data folders.
        root_output_dir: Root under which images and summaries are written.
        args: Parsed CLI arguments, passed through to process_single_folder;
            disable_no_hl_all is also read here for the status printout.

    Returns:
        None. When no data folder is found the function prints a message and
        returns without creating anything.
    """
    print(f"Scanning for data folders in: {root_input_dir}...")
    data_folders_to_process = []
    for subdir, dirs, files in os.walk(root_input_dir):
        # Sorting in place makes the traversal, and therefore the order of
        # the collected metadata, reproducible between runs.
        dirs.sort()
        edf_names = [fname for fname in files if fname.lower().endswith('.edf')]
        if not edf_names:
            continue
        # One matching .rec is enough to make the folder worth processing.
        has_rec = any(
            os.path.exists(os.path.join(subdir, fname[:-4] + '.rec'))
            for fname in edf_names
        )
        if has_rec:
            data_folders_to_process.append(subdir)
        else:
            print(f" Skipping folder (no .rec found): {subdir}")

    if not data_folders_to_process:
        print("No folders containing corresponding .edf and .rec files found.")
        return

    print(f"Found {len(data_folders_to_process)} data folder(s) with EDF/REC pairs to process.")
    print(f"Outputting batched 'all channels' images with {CHANNELS_PER_BATCH_IMAGE} channels per image.")
    print("Image generation settings:")
    print(f" - No Highlight All (Batched): {'Enabled' if not args.disable_no_hl_all else 'Disabled'}")

    global_metadata_list = []
    global_stats = Counter({label: 0 for label in all_possible_labels})
    total_files_processed_global = 0
    total_images_generated_global = 0  # individual image files
    processed_folders_count = 0

    for subdir in tqdm(data_folders_to_process, desc="Overall Progress", unit="folder"):
        processed_folders_count += 1
        # Mirror the input tree below the output root.
        relative_path = os.path.relpath(subdir, root_input_dir)
        current_output_dir = os.path.join(root_output_dir, relative_path)
        folder_metadata, folder_stats, folder_files_processed, folder_images_generated = process_single_folder(
            subdir, current_output_dir, args
        )
        global_metadata_list.extend(folder_metadata)
        global_stats.update(folder_stats)
        total_files_processed_global += folder_files_processed
        total_images_generated_global += folder_images_generated

    # --- Global JSON summary ---
    # At least one folder was processed at this point (the empty case
    # returned above), so the summary is always written.
    global_json_path = os.path.join(root_output_dir, GLOBAL_JSON_FILENAME)
    global_summary_data = {
        "root_input_directory": os.path.abspath(root_input_dir),
        "root_output_directory": os.path.abspath(root_output_dir),
        "channels_per_batched_image": CHANNELS_PER_BATCH_IMAGE,
        "total_folders_processed": processed_folders_count,
        "total_edf_rec_pairs_processed": total_files_processed_global,
        "total_images_generated": total_images_generated_global,
    }
    # One "count_<label>" key per label; every known label is present even
    # when it never occurred.
    for label, count in global_stats.items():
        global_summary_data[f"count_{label}"] = count
    for label in all_possible_labels:
        global_summary_data.setdefault(f"count_{label}", 0)
    # The (potentially long) metadata list goes last.
    global_summary_data["all_image_metadata"] = global_metadata_list

    try:
        os.makedirs(root_output_dir, exist_ok=True)
        with open(global_json_path, 'w') as f_json:
            # default=str keeps the dump from failing on stray
            # non-serialisable values.
            json.dump(global_summary_data, f_json, indent=4, default=str)
    except Exception as e:
        print(f"\nError: Failed to save global JSON summary '{global_json_path}': {e}")
    else:
        print("\n--- === Global Processing Complete === ---")
        print(f"Processed {processed_folders_count} data folder(s).")
        print(f"Processed a total of {total_files_processed_global} EDF/REC file pairs.")
        print(f"Generated a total of {total_images_generated_global} individual image files (including batches).")
        print(f"Global summary and metadata saved to: {global_json_path}")
        print("Global annotation counts:")
        for label in sorted(all_possible_labels):
            print(f" - {label}: {global_stats[label]}")
# --- 执行主程序 ---
if __name__ == "__main__":
    # --- Command-line interface ---
    parser = argparse.ArgumentParser(
        description="Generate cropped EEG images from EDF/REC pairs in a directory structure, optionally batching multi-channel plots."
    )
    parser.add_argument(
        '-i', '--input-dir', type=str, default=ROOT_INPUT_DIR,
        help=f"Root directory containing the input data folders (e.g., eval/, train/). Default: {ROOT_INPUT_DIR}",
    )
    parser.add_argument(
        '-o', '--output-dir', type=str, default=ROOT_OUTPUT_DIR,
        help=f"Root directory where output images and JSON summaries will be saved. Default: {ROOT_OUTPUT_DIR}",
    )
    # Image types are enabled by default; this flag switches the batched
    # "no highlight, all channels" images off.
    parser.add_argument(
        '--disable-no-hl-all', action='store_true', default=False,
        help="Disable generation of 'no highlight, all channels' (batched) images.",
    )
    args = parser.parse_args()

    root_input_dir = args.input_dir
    root_output_dir = args.output_dir

    # Refuse to run without a valid input root.
    if not os.path.isdir(root_input_dir):
        print(f"错误:输入根目录 '{root_input_dir}' 不存在或不是一个目录。请检查 --input-dir 参数或脚本中的 ROOT_INPUT_DIR。")
        sys.exit(1)

    # Make sure the output root exists before any folder is processed.
    os.makedirs(root_output_dir, exist_ok=True)
    process_directory_recursive(root_input_dir, root_output_dir, args)