# File-viewer metadata carried over with the listing: 589 lines, 31 KiB, Python.
import argparse
import glob
import json
import math  # For ceil
import os
import sys
import warnings
from collections import Counter  # For counting annotations

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.transforms as transforms
import mne
import numpy as np
from tqdm import tqdm  # <<< Import tqdm

# --- User configuration ---

ROOT_INPUT_DIR = "/data/shuiyuan/workplace/EEG/isip.piconepress.com/data/tuh_eeg/tuh_eeg_events/v2.0.1/edf/"
ROOT_OUTPUT_DIR = "/data1/gongwenxin/workspace/eeg/dataset/tuh/eeg_annotation_plots_filtered_5s_resample/"

# The only image output sub-directory that is still produced.
OUTPUT_SUBDIR_NO_HIGHLIGHT_ALL = "no_highlight/all_channels_batched"
FOLDER_JSON_FILENAME = "_folder_summary.json"
GLOBAL_JSON_FILENAME = "global_summary_metadata.json"

# --- Batched plotting parameters ---
# Number of channels drawn in each "all channels" image.
CHANNELS_PER_BATCH_IMAGE = 5

# --- Plot parameters ---
PLOT_BUFFER_SECONDS = 2
MIN_PLOT_DURATION = 1.0
FIG_WIDTH_INCHES = 12
FIG_HEIGHT_PER_CHANNEL = 2  # Kept fairly tall to reduce label overlap.
DPI = 100
LINE_COLOR = 'black'
LINE_WIDTH = 0.5
ZERO_LINE_COLOR = 'gray'
ZERO_LINE_STYLE = '-'
ZERO_LINE_WIDTH = 0.5
VOLTAGE_LABEL_PRECISION = 0
Y_SCALING_MODE = 'fixed'
FIXED_Y_SCALE_UV = 150
Y_PADDING_FACTOR = 0.1

# --- TCP montage definition ---
# The three lists are parallel: ch_names_tcp[i] is the bipolar derivation
# anodes[i] minus cathodes[i].
anodes = [
    'EEG FP1-REF', 'EEG F7-REF', 'EEG T3-REF', 'EEG T5-REF',
    'EEG FP2-REF', 'EEG F8-REF', 'EEG T4-REF', 'EEG T6-REF',
    'EEG A1-REF', 'EEG T3-REF', 'EEG C3-REF', 'EEG CZ-REF',
    'EEG C4-REF', 'EEG T4-REF', 'EEG FP1-REF', 'EEG F3-REF',
    'EEG C3-REF', 'EEG P3-REF', 'EEG FP2-REF', 'EEG F4-REF',
    'EEG C4-REF', 'EEG P4-REF',
]
cathodes = [
    'EEG F7-REF', 'EEG T3-REF', 'EEG T5-REF', 'EEG O1-REF',
    'EEG F8-REF', 'EEG T4-REF', 'EEG T6-REF', 'EEG O2-REF',
    'EEG T3-REF', 'EEG C3-REF', 'EEG CZ-REF', 'EEG C4-REF',
    'EEG T4-REF', 'EEG A2-REF', 'EEG F3-REF', 'EEG C3-REF',
    'EEG P3-REF', 'EEG O1-REF', 'EEG F4-REF', 'EEG C4-REF',
    'EEG P4-REF', 'EEG O2-REF',
]
ch_names_tcp = [
    'FP1-F7', 'F7-T3', 'T3-T5', 'T5-O1',
    'FP2-F8', 'F8-T4', 'T4-T6', 'T6-O2',
    'A1-T3', 'T3-C3', 'C3-CZ', 'CZ-C4',
    'C4-T4', 'T4-A2', 'FP1-F3', 'F3-C3',
    'C3-P3', 'P3-O1', 'FP2-F4', 'F4-C4',
    'C4-P4', 'P4-O2',
]

# Mapping from the numeric label code in a REC file to its short description.
rec_label_map = {1: 'spsw', 2: 'gped', 3: 'pled', 4: 'eyem', 5: 'artf', 6: 'bckg'}
all_possible_labels = list(rec_label_map.values())


def plot_eeg_matplotlib_flexible_scale(
    data, times, ch_names, title, save_path,
    y_scaling_mode='fixed', fixed_y_scale_uv=100, y_padding_factor=0.1,
    fig_width=16, fig_height_per_channel=0.6, dpi=100,
    line_color='black', line_width=0.5,
    zero_line_color='dimgray', zero_line_style='-', zero_line_width=0.6,
    guide_line_color='lightgray', guide_line_style='-', guide_line_width=0.5,
    show_channel_voltage_labels=True,
    channel_voltage_label_color='gray',
    voltage_label_precision=0,
    x_tick_interval=0.1
):
    """Plot stacked EEG traces on a single Axes and save the figure.

    Every channel is drawn with its own vertical offset. For each channel a
    baseline (labelled "0 μV") and six horizontal guide lines at +/-1/3,
    +/-2/3 and +/-1 of the display scale are drawn; optionally each guide
    line gets a signed voltage label at the right edge of the axes.

    Warning: with ``show_channel_voltage_labels=True`` the labels may overlap
    when the channels are packed densely.

    Args:
        data: 2-D array indexed as ``data[channel]``. It is multiplied by
            1e6 before plotting and labelled in μV.
        times: 1-D sequence of x values; its first and last entries define
            the x-axis limits.
        ch_names: Channel names used as y tick labels, one per row of *data*.
        title: Figure title.
        save_path: Destination image path; parent directories are created.
        y_scaling_mode (str): ``'fixed'`` uses ``abs(fixed_y_scale_uv)``;
            ``'global_auto'`` uses the largest absolute sample times
            ``1 + y_padding_factor``. ``'individual_auto'`` emits a warning
            and behaves like ``'global_auto'``.
        fixed_y_scale_uv (float): Half-range of each channel band in μV.
        y_padding_factor (float): Head-room factor for the auto modes.
        fig_width (float): Figure width in inches.
        fig_height_per_channel (float): Figure height per channel in inches.
        dpi (int): Figure resolution.
        line_color (str): Colour of the signal traces.
        line_width (float): Width of the signal traces.
        zero_line_color (str): Colour of each channel's offset baseline.
        zero_line_style (str): Line style of each channel's baseline.
        zero_line_width (float): Line width of each channel's baseline.
        guide_line_color (str): Colour of the horizontal guide lines.
        guide_line_style (str): Line style of the guide lines.
        guide_line_width (float): Line width of the guide lines.
        show_channel_voltage_labels (bool): Whether to label the guide lines.
        channel_voltage_label_color (str): Colour of the guide-line labels.
        voltage_label_precision (int): Decimal places in the voltage labels.
        x_tick_interval (float | None): Minor x grid spacing; the major grid
            uses ten times this value. ``None`` or a non-positive value
            disables the x grid.

    Returns:
        bool: True if the image was saved, False if there was nothing to
        plot or saving failed.
    """
    n_channels = data.shape[0]
    if n_channels == 0:
        warnings.warn("数据中没有通道,无法绘图。")
        return False
    if len(times) == 0:
        # Without samples times[0] / times[-1] below would raise IndexError.
        warnings.warn("No time samples supplied; nothing to plot.")
        return False

    total_fig_height = max(n_channels * fig_height_per_channel + 1.5, 4.0)
    fig, ax = plt.subplots(1, 1, figsize=(fig_width, total_fig_height), dpi=dpi)

    # From here on the figure exists, so make sure it is always closed, even
    # if a plotting call raises.
    try:
        plot_start_time = times[0]
        plot_end_time = times[-1]
        actual_plot_duration = plot_end_time - plot_start_time
        data_uv = data * 1e6

        # --- Y scale and per-channel offsets ---
        channel_display_scale_uv = abs(fixed_y_scale_uv)
        if y_scaling_mode in ('global_auto', 'individual_auto'):
            if y_scaling_mode == 'individual_auto':
                warnings.warn("'individual_auto' scaling is not suitable for stacked plots. Using 'global_auto' behavior.")
            if data_uv.size > 0:
                global_max_abs = np.max(np.abs(data_uv))
                # A flat signal keeps the fixed scale.
                if global_max_abs > 1e-6:
                    channel_display_scale_uv = global_max_abs * (1 + y_padding_factor)

        if channel_display_scale_uv <= 0:
            warnings.warn(f"计算出的 channel_display_scale_uv ({channel_display_scale_uv:.2f}) 无效,将使用默认值 100uV。")
            channel_display_scale_uv = 100.0

        fixed_gap_uv = 50.0
        offset_step = 2 * channel_display_scale_uv + fixed_gap_uv
        # First channel on top, last channel at offset 0.
        y_offsets = [(n_channels - 1 - i) * offset_step for i in range(n_channels)]
        # Guide levels relative to each channel's baseline.
        guide_levels_uv = [
            channel_display_scale_uv,
            channel_display_scale_uv * 2 / 3,
            channel_display_scale_uv * 1 / 3,
            -channel_display_scale_uv * 1 / 3,
            -channel_display_scale_uv * 2 / 3,
            -channel_display_scale_uv,
        ]
        label_format = f"{{:+.{voltage_label_precision}f}} μV"  # Signed value.

        # x in axes coordinates, y in data coordinates: labels sit just
        # right of the axes at the height of the line they describe.
        trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)

        # --- Per-channel drawing ---
        for channel_data, offset in zip(data_uv, y_offsets):
            # 1. Signal
            ax.plot(times, channel_data + offset, color=line_color, linewidth=line_width, zorder=3)

            # 2. Baseline (0 μV)
            ax.axhline(offset, color=zero_line_color, linestyle=zero_line_style,
                       linewidth=zero_line_width, zorder=2, alpha=0.9)
            ax.text(1.001, offset, " 0 μV", transform=trans,
                    va='center', ha='left', fontsize=6, color=zero_line_color, zorder=2)

            # 3. Guide lines, 4. optional voltage label next to each of them
            for level_uv in guide_levels_uv:
                ax.axhline(offset + level_uv, color=guide_line_color, linestyle=guide_line_style,
                           linewidth=guide_line_width, alpha=0.7, zorder=1)
                if show_channel_voltage_labels:
                    ax.text(1.001, offset + level_uv, label_format.format(level_uv),
                            transform=trans, va='center', ha='left', fontsize=6,
                            color=channel_voltage_label_color, zorder=2)

        # --- X axis and grid ---
        ax.set_xlim(plot_start_time, plot_end_time)
        ax.set_xlabel("Time (s)", fontsize=9)
        ax.tick_params(axis='x', labelsize=7)

        if actual_plot_duration > 0 and x_tick_interval is not None and x_tick_interval > 0:
            # Major grid every 10 * x_tick_interval, minor grid every
            # x_tick_interval.
            ax.xaxis.set_major_locator(ticker.MultipleLocator(base=x_tick_interval * 10))
            ax.xaxis.grid(True, which='major', linestyle='-', linewidth=0.5,
                          color='darkgrey', alpha=0.7, zorder=0)
            ax.xaxis.set_minor_locator(ticker.MultipleLocator(base=x_tick_interval))
            ax.xaxis.grid(True, which='minor', linestyle='-', linewidth=0.5,
                          color='lightgray', alpha=0.7, zorder=0)
            ax.yaxis.grid(False)
        else:
            ax.xaxis.grid(False)
            ax.yaxis.grid(False)

        # --- Y axis ---
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_visible(True)
        ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=True)
        ax.set_yticks(y_offsets)
        ax.set_yticklabels(ch_names, fontsize=8, va='center')
        y_min_limit = y_offsets[-1] + min(guide_levels_uv) - channel_display_scale_uv * 0.2
        y_max_limit = y_offsets[0] + max(guide_levels_uv) + channel_display_scale_uv * 0.2
        if y_max_limit <= y_min_limit:
            y_max_limit = y_min_limit + offset_step
        ax.set_ylim(y_min_limit, y_max_limit)

        # --- Title and layout ---
        fig.suptitle(title, fontsize=12)
        # Lay out this figure explicitly instead of pyplot's current figure.
        fig.tight_layout()

        # --- Save ---
        try:
            out_dir = os.path.dirname(save_path)
            if out_dir:  # A bare file name has no directory to create.
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(save_path)
            return True
        except Exception as e:
            print(f" 错误: 保存图像到 '{save_path}' 时失败: {e}")
            return False
    finally:
        plt.close(fig)


# --- Single-folder processing function ---
def _parse_rec_annotations(rec_path, tcp_channel_names):
    """Read a REC file and return its valid annotations as a list of dicts.

    Each usable line has four comma-separated fields: channel index, start
    time, stop time and label code. Lines with a different field count,
    non-numeric fields, an out-of-range channel index, a non-positive
    duration or an unknown label code are skipped.
    """
    annotations = []
    with open(rec_path, 'r') as f_rec:
        for line_num, line in enumerate(f_rec, start=1):
            line = line.strip()
            parts = line.split(',')
            if not line or len(parts) != 4:
                continue
            try:
                ch_idx = int(parts[0].strip())
                start = float(parts[1].strip())
                stop = float(parts[2].strip())
                lbl_code = int(parts[3].strip())
            except ValueError:
                continue  # Ignore lines with non-numeric values.
            duration = stop - start
            if (not (0 <= ch_idx < len(tcp_channel_names))
                    or duration <= 0
                    or lbl_code not in rec_label_map):
                continue
            annotations.append({
                "ch_index": ch_idx,  # Index within the final TCP montage
                "ch_name": tcp_channel_names[ch_idx],
                "start": start, "stop": stop, "duration": duration,
                "label_code": lbl_code, "label": rec_label_map[lbl_code],
                "rec_line": line_num,
            })
    return annotations


def process_single_folder(input_folder, output_folder_base, args):
    """Process every EDF/REC pair in one folder and write images + metadata.

    For each ``*.edf`` file with a sibling ``.rec`` file the recording is
    band-pass filtered, notch filtered, resampled and re-referenced to the
    TCP bipolar montage. Every valid REC annotation then triggers one set of
    "all channels" images covering the annotation plus a buffer, split into
    batches of ``CHANNELS_PER_BATCH_IMAGE`` channels. A per-folder JSON
    summary is written to ``output_folder_base``.

    Args:
        input_folder: Directory containing the ``.edf`` / ``.rec`` files.
        output_folder_base: Directory receiving the images and the summary.
        args: Parsed CLI namespace; only ``disable_no_hl_all`` is read.

    Returns:
        tuple: ``(metadata_list, label_counter, files_processed,
        images_generated)``. ``metadata_list`` has one entry per annotation
        that produced at least one image.
    """
    folder_name = os.path.basename(input_folder)
    output_dirs = {
        "no_hl_all": os.path.join(output_folder_base, OUTPUT_SUBDIR_NO_HIGHLIGHT_ALL),
    }
    if not args.disable_no_hl_all:
        os.makedirs(output_dirs["no_hl_all"], exist_ok=True)

    folder_metadata_list = []
    folder_stats = Counter({label: 0 for label in all_possible_labels})
    processed_files_count = 0
    generated_images_count = 0  # Total number of image files written.

    edf_files = sorted(glob.glob(os.path.join(input_folder, '*.edf')))
    if not edf_files:
        return [], folder_stats, 0, 0

    def standardize(name):
        # 'EEG FP1-REF' -> 'FP1'; applied to both the recording's channels
        # and the montage definition so that they can be matched.
        return name.replace('EEG ', '').replace('-REF', '').replace('.', '').upper()

    # Loop-invariant values, computed once per folder.
    anodes_std = [standardize(ch) for ch in anodes]
    cathodes_std = [standardize(ch) for ch in cathodes]
    required_channels = set(anodes_std + cathodes_std)
    common_plot_args = {
        "y_scaling_mode": Y_SCALING_MODE,
        "fixed_y_scale_uv": FIXED_Y_SCALE_UV,
        "y_padding_factor": Y_PADDING_FACTOR,
        "fig_width": FIG_WIDTH_INCHES,
        "dpi": DPI,
        "line_color": LINE_COLOR,
        "line_width": LINE_WIDTH,
        "zero_line_color": ZERO_LINE_COLOR,
        "zero_line_style": ZERO_LINE_STYLE,
        "zero_line_width": ZERO_LINE_WIDTH,
        "voltage_label_precision": VOLTAGE_LABEL_PRECISION,
    }

    for edf_path in tqdm(edf_files, desc=f"Processing {folder_name}", unit="file", leave=False):
        base_filename = os.path.splitext(os.path.basename(edf_path))[0]
        rec_path = os.path.join(input_folder, base_filename + '.rec')

        if not os.path.exists(rec_path):
            tqdm.write(f" Warning: REC file not found for '{base_filename}.edf' in {folder_name}. Skipping pair.")
            continue

        # --- 1. Load and preprocess the EEG recording ---
        try:
            with mne.utils.use_log_level('WARNING'):
                raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
                # tqdm.write keeps the progress bar intact, like the other
                # messages in this function.
                tqdm.write(f"采样频率: {raw.info['sfreq']} Hz")  # Sampling frequency
                tqdm.write(f"信息: {raw.info}")  # Full recording info

                raw.filter(l_freq=0.1, h_freq=75.0)
                # NOTE(review): 50 Hz is notched here; confirm this matches
                # the mains frequency of the recording site.
                raw.notch_filter(50.0)
                raw.resample(100, n_jobs=5)
            raw.rename_channels(standardize, verbose=False)

            missing_channels = required_channels - set(raw.ch_names)
            if missing_channels:
                tqdm.write(f" Error: EDF {base_filename} missing required TCP channels: {missing_channels}. Skipping.")
                continue

            with mne.utils.use_log_level('ERROR'):
                try:
                    raw.set_montage(mne.channels.make_standard_montage('standard_1020'),
                                    on_missing='warn', verbose=False)
                except Exception:
                    # Deliberate best effort: electrode positions are not
                    # needed for the plots, so a failing montage is ignored.
                    pass
                raw_tcp = mne.set_bipolar_reference(
                    raw, anode=anodes_std, cathode=cathodes_std,
                    ch_name=ch_names_tcp, copy=True, verbose=False)
                raw_tcp.pick_channels(ch_names_tcp)  # Keep only TCP channels.
        except Exception as e:
            tqdm.write(f" Error: Failed to load/preprocess EDF '{base_filename}' in {folder_name}: {e}. Skipping.")
            continue

        # --- 2. Parse the whole REC file ---
        try:
            # Channel indices are validated against the channels actually
            # present in raw_tcp.
            all_annotations_in_file = _parse_rec_annotations(rec_path, raw_tcp.ch_names)
        except Exception as e:
            tqdm.write(f" Error: Failed to read/parse REC '{base_filename}.rec' in {folder_name}: {e}. Skipping file pair.")
            continue

        if not all_annotations_in_file:
            processed_files_count += 1  # Counted as processed, but no images.
            continue
        folder_stats.update(annot["label"] for annot in all_annotations_in_file)

        # --- 3. Each annotation triggers one set of images ---
        for trigger_annot in all_annotations_in_file:
            start_time = trigger_annot['start']
            duration = trigger_annot['duration']
            annotated_channel_name = trigger_annot['ch_name']
            label_description = trigger_annot['label']
            ch_index = trigger_annot['ch_index']
            rec_line = trigger_annot['rec_line']

            # Plot window: the annotation plus a buffer on both sides,
            # centred on the annotation and clipped to the recording.
            plot_duration = max(duration + 2 * PLOT_BUFFER_SECONDS, MIN_PLOT_DURATION)
            plot_center = start_time + duration / 2.0
            plot_start = max(0, plot_center - plot_duration / 2.0)
            plot_end = min(raw_tcp.times[-1], plot_center + plot_duration / 2.0)
            actual_plot_duration = plot_end - plot_start
            if actual_plot_duration < 0.1:  # Very short or invalid interval.
                tqdm.write(f" Skipping annotation at {start_time:.2f}s: plot duration too short ({actual_plot_duration:.3f}s).")
                continue

            try:
                with mne.utils.use_log_level('ERROR'):
                    cropped_raw = raw_tcp.copy().crop(
                        tmin=plot_start, tmax=plot_end, include_tmax=True, verbose=False)
                    data_cropped, times_cropped = cropped_raw.get_data(return_times=True)
                if data_cropped.size == 0 or times_cropped.size == 0:
                    tqdm.write(f" Skipping annotation at {start_time:.2f}s: Cropped data is empty.")
                    continue
            except Exception as crop_err:
                tqdm.write(f" Error cropping data for annotation at {start_time:.2f}s: {crop_err}. Skipping annotation.")
                continue

            # Base image name, without batch information.
            safe_time_str = f"{start_time:.2f}".replace('.', 'p')
            base_img_name_trigger = f"{base_filename}_ch{ch_index}_{annotated_channel_name}_{safe_time_str}s_{label_description}"

            img_success_paths_nh_all = []
            images_generated_this_trigger = 0

            n_channels_total = data_cropped.shape[0]
            num_batches = math.ceil(n_channels_total / CHANNELS_PER_BATCH_IMAGE)

            # No highlight, all channels, split into batches.
            if not args.disable_no_hl_all:
                for i_batch in range(num_batches):
                    start_idx = i_batch * CHANNELS_PER_BATCH_IMAGE
                    end_idx = min(start_idx + CHANNELS_PER_BATCH_IMAGE, n_channels_total)
                    data_batch = data_cropped[start_idx:end_idx, :]
                    ch_names_batch = cropped_raw.ch_names[start_idx:end_idx]
                    if data_batch.size == 0:
                        continue  # Skip an empty batch.

                    batch_suffix = f"_batch{i_batch+1}of{num_batches}" if num_batches > 1 else ""
                    save_path_nh_all_batch = os.path.abspath(os.path.join(
                        output_dirs["no_hl_all"],
                        f"{base_img_name_trigger}_noHL_ALL{batch_suffix}.png"))

                    success_nh_a = plot_eeg_matplotlib_flexible_scale(
                        data=data_batch, times=times_cropped, ch_names=ch_names_batch,
                        title="", save_path=save_path_nh_all_batch,
                        fig_height_per_channel=FIG_HEIGHT_PER_CHANNEL,
                        **common_plot_args
                    )
                    if success_nh_a:
                        images_generated_this_trigger += 1
                        img_success_paths_nh_all.append(save_path_nh_all_batch)

            # Store metadata only if this annotation produced an image.
            if images_generated_this_trigger > 0:
                folder_metadata_list.append({
                    "original_edf_abs": os.path.abspath(edf_path),
                    "original_rec_abs": os.path.abspath(rec_path),
                    "triggering_rec_line": rec_line,
                    "triggering_annotation": trigger_annot,
                    "plot_window_start_sec": plot_start,
                    "plot_window_end_sec": plot_end,
                    "actual_plot_duration_sec": actual_plot_duration,
                    # List of image paths, or None when nothing was written.
                    "image_paths_no_highlight_all_abs": img_success_paths_nh_all if img_success_paths_nh_all else None,
                })
                generated_images_count += images_generated_this_trigger

        processed_files_count += 1  # File pair handled, including annotations.

    # --- Save this folder's summary JSON ---
    folder_summary_path = os.path.join(output_folder_base, FOLDER_JSON_FILENAME)
    folder_summary_data = {
        "input_folder": os.path.abspath(input_folder),
        "output_folder": os.path.abspath(output_folder_base),
        "channels_per_batched_image": CHANNELS_PER_BATCH_IMAGE,
        "files_processed": processed_files_count,
        "images_generated": generated_images_count,  # Individual image files.
        "annotation_counts": dict(folder_stats),
        "image_metadata": folder_metadata_list,  # One entry per annotation.
    }
    try:
        if processed_files_count > 0 or generated_images_count > 0:
            os.makedirs(output_folder_base, exist_ok=True)
            with open(folder_summary_path, 'w') as f_json:
                # default=str stringifies anything json cannot encode itself.
                json.dump(folder_summary_data, f_json, indent=4, default=str)
    except Exception as e:
        tqdm.write(f" Error: Failed to save folder summary JSON '{folder_summary_path}': {e}")

    return folder_metadata_list, folder_stats, processed_files_count, generated_images_count


# --- Main recursive processing function ---
def process_directory_recursive(root_input_dir, root_output_dir, args):
    """Process every data folder below *root_input_dir*.

    A data folder is a directory holding at least one ``.edf`` file that has
    a ``.rec`` file with the same base name. Each such folder is passed to
    ``process_single_folder``; its output goes to the same relative path
    below *root_output_dir*. Afterwards a global JSON summary with counts
    and all image metadata is written to *root_output_dir*.

    Args:
        root_input_dir: Root directory to scan.
        root_output_dir: Root directory for images and JSON summaries.
        args: Parsed CLI namespace, passed on to ``process_single_folder``.

    Returns:
        None. If no data folder is found a message is printed and nothing is
        written.
    """
    print(f"Scanning for data folders in: {root_input_dir}...")

    data_folders_to_process = []
    for subdir, dirs, files in os.walk(root_input_dir):
        dirs.sort()  # Visit sub-directories in a reproducible order.
        edf_names = [fname for fname in files if fname.lower().endswith('.edf')]
        if not edf_names:
            continue
        # A folder qualifies when at least one EDF has a matching REC file.
        has_rec = any(
            os.path.exists(os.path.join(subdir, os.path.splitext(fname)[0] + '.rec'))
            for fname in edf_names
        )
        if has_rec:
            data_folders_to_process.append(subdir)
        else:
            print(f" Skipping folder (no .rec found): {subdir}")

    if not data_folders_to_process:
        print("No folders containing corresponding .edf and .rec files found.")
        return

    print(f"Found {len(data_folders_to_process)} data folder(s) with EDF/REC pairs to process.")
    print(f"Outputting batched 'all channels' images with {CHANNELS_PER_BATCH_IMAGE} channels per image.")
    print("Image generation settings:")
    print(f" - No Highlight All (Batched): {'Enabled' if not args.disable_no_hl_all else 'Disabled'}")

    global_metadata_list = []
    global_stats = Counter({label: 0 for label in all_possible_labels})
    total_files_processed_global = 0
    total_images_generated_global = 0  # Individual image files.
    processed_folders_count = 0

    for subdir in tqdm(data_folders_to_process, desc="Overall Progress", unit="folder"):
        processed_folders_count += 1
        # Mirror the input layout below the output root.
        relative_path = os.path.relpath(subdir, root_input_dir)
        current_output_dir = os.path.join(root_output_dir, relative_path)

        folder_metadata, folder_stats, folder_files_processed, folder_images_generated = process_single_folder(
            subdir, current_output_dir, args
        )

        global_metadata_list.extend(folder_metadata)
        global_stats.update(folder_stats)
        total_files_processed_global += folder_files_processed
        total_images_generated_global += folder_images_generated

    # --- Save the global JSON metadata and statistics ---
    global_json_path = os.path.join(root_output_dir, GLOBAL_JSON_FILENAME)
    global_summary_data = {
        "root_input_directory": os.path.abspath(root_input_dir),
        "root_output_directory": os.path.abspath(root_output_dir),
        "channels_per_batched_image": CHANNELS_PER_BATCH_IMAGE,
        "total_folders_processed": processed_folders_count,
        "total_edf_rec_pairs_processed": total_files_processed_global,
        "total_images_generated": total_images_generated_global,
    }
    # One "count_<label>" entry per label; global_stats was seeded with
    # every known label, so labels that never occurred appear with 0.
    for label, count in global_stats.items():
        global_summary_data[f"count_{label}"] = count
    # The per-annotation metadata goes last; entries may hold path lists.
    global_summary_data["all_image_metadata"] = global_metadata_list

    try:
        os.makedirs(root_output_dir, exist_ok=True)
        with open(global_json_path, 'w') as f_json:
            # default=str stringifies anything json cannot encode itself.
            json.dump(global_summary_data, f_json, indent=4, default=str)

        print("\n--- === Global Processing Complete === ---")
        print(f"Processed {processed_folders_count} data folder(s).")
        print(f"Processed a total of {total_files_processed_global} EDF/REC file pairs.")
        print(f"Generated a total of {total_images_generated_global} individual image files (including batches).")
        print(f"Global summary and metadata saved to: {global_json_path}")
        print("Global annotation counts:")
        for label in sorted(all_possible_labels):
            print(f" - {label}: {global_stats[label]}")
    except Exception as e:
        print(f"\nError: Failed to save global JSON summary '{global_json_path}': {e}")


# --- Script entry point ---
if __name__ == "__main__":
    # --- Argument parsing ---
    parser = argparse.ArgumentParser(description="Generate cropped EEG images from EDF/REC pairs in a directory structure, optionally batching multi-channel plots.")
    parser.add_argument('-i', '--input-dir', type=str, default=ROOT_INPUT_DIR,
                        help=f"Root directory containing the input data folders (e.g., eval/, train/). Default: {ROOT_INPUT_DIR}")
    parser.add_argument('-o', '--output-dir', type=str, default=ROOT_OUTPUT_DIR,
                        help=f"Root directory where output images and JSON summaries will be saved. Default: {ROOT_OUTPUT_DIR}")
    # Image types are enabled by default; this flag switches one off.
    parser.add_argument('--disable-no-hl-all', action='store_true', default=False,
                        help="Disable generation of 'no highlight, all channels' (batched) images.")

    args = parser.parse_args()

    root_input_dir = args.input_dir
    root_output_dir = args.output_dir

    if not os.path.isdir(root_input_dir):
        print(f"错误:输入根目录 '{root_input_dir}' 不存在或不是一个目录。请检查 --input-dir 参数或脚本中的 ROOT_INPUT_DIR。")
        # Nothing can be processed without a valid input directory.
        sys.exit(1)

    # Make sure the output root exists before any folder is processed.
    os.makedirs(root_output_dir, exist_ok=True)
    process_directory_recursive(root_input_dir, root_output_dir, args)