superbrucejia / eeg-motor-imagery-classification-cnns-tensorflow

EEG Motor Imagery Tasks Classification (by Channels) via Convolutional Neural Networks (CNNs) based on TensorFlow

Home Page: https://iopscience.iop.org/article/10.1088/1741-2552/ab4af6/meta

Python 79.18% MATLAB 20.82%
eeg-data motor-imagery-tasks cnns convolutional-neural-networks esi tensorflow python eeg eeg-signals eeg-analysis

EEG Motor Imagery Signals (Tasks) Classification via Convolutional Neural Networks (CNN)

Authors: Shuyue Jia and Lu Zhou, School of Automation Engineering, Northeast Electric Power University, Jilin, China.

Date: December 2018

Paper: A Novel Approach of Decoding EEG Four-class Motor Imagery Tasks via Scout ESI and CNN

NOTICE: The method in our paper is EEG source imaging (ESI) + Morlet wavelet joint time-frequency analysis (JTFA) + convolutional neural networks (CNNs). The raw data were processed with the MATLAB toolbox Brainstorm; my part of the work was using CNNs to classify the EEG data after the ESI + JTFA processing. The dataset (.mat files) preprocessed with ESI + JTFA is available on the shared Google Drive, as are the corresponding preprocessed Excel files, trained checkpoints, and evaluation results.

Note that the code in this repository operates on the raw EEG data, without the ESI and JTFA processing, and still achieves good results. The main CNN TensorFlow code in "MI_Proposed_CNNs_Architecture.py" is the same for both works.
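
The ESI and Morlet wavelet JTFA steps were run in Brainstorm, not in this repository's code. To make the JTFA idea concrete, here is a minimal NumPy sketch of complex Morlet wavelet time-frequency power for a single channel. It is not Brainstorm's implementation (its wavelet width and normalization may differ); the function name `morlet_power`, the 7-cycle wavelet width, and the 160 Hz, 4 s test trial are assumptions chosen for illustration.

```python
import numpy as np


def morlet_power(signal, fs, freqs, n_cycles=7.0):
    """Time-frequency power of a 1-D signal via complex Morlet wavelet convolution.

    Returns an array of shape (len(freqs), len(signal)).
    """
    signal = np.asarray(signal, dtype=float)
    power = np.empty((len(freqs), signal.size))
    for i, f in enumerate(freqs):
        # Standard deviation (seconds) of the Gaussian envelope at this frequency.
        sigma_t = n_cycles / (2.0 * np.pi * f)
        t = np.arange(-3.5 * sigma_t, 3.5 * sigma_t, 1.0 / fs)
        if t.size > signal.size:
            raise ValueError("wavelet at %.1f Hz is longer than the signal" % f)
        # Complex sinusoid under a Gaussian window, scaled to unit energy.
        wavelet = np.exp(2j * np.pi * f * t) * np.exp(-t ** 2 / (2.0 * sigma_t ** 2))
        wavelet /= np.sqrt(np.sum(np.abs(wavelet) ** 2))
        # mode="same" keeps the output aligned sample-by-sample with the input.
        power[i] = np.abs(np.convolve(signal, wavelet, mode="same")) ** 2
    return power


fs = 160                         # assumed sampling rate in Hz
t = np.arange(4 * fs) / fs       # 4 s trial -> 640 samples
x = np.sin(2 * np.pi * 10 * t)   # synthetic 10 Hz oscillation
freqs = np.array([6.0, 10.0, 20.0])
tfr = morlet_power(x, fs, freqs)
print(tfr.shape)  # (3, 640)
```

Away from the edges of the trial, the 10 Hz row carries by far the most power, which is what a time-frequency map of a 10 Hz rhythm should show. Larger `n_cycles` trades time resolution for frequency resolution.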


Overall Framework: [figure]

Proposed CNNs Architecture: [figure]

Installation and Usage

  1. Python file: PhysioNet_MI_Dataset/MIND_Get_EDF.py

    --- Downloads all the .edf files of the PhysioNet EEG Motor Movement/Imagery Dataset

    (Under Any Python Environment) $ python MIND_Get_EDF.py
    
  2. Python file: Read_Raw_Data_Save_Into_Matlab_Files.py

    --- Reads the raw .edf data of the different channels and saves it into MATLAB .mat files

    --- This script must be run under Python 2 (Python 2.7 is recommended).

    (Under Python 2.7 Environment) $ python Read_Raw_Data_Save_Into_Matlab_Files.py
    
  3. Matlab file: Saved_Matlab_Data/Preprocessing_Raw_Data.m

    --- Preprocesses the dataset (mainly data normalization) and converts the MATLAB .mat files into Excel .xlsx files

  4. Python file: MI_Proposed_CNNs_Architecture.py

    --- The proposed CNN architecture

    --- Built on TensorFlow 1.12.0 with CUDA 9.0, or TensorFlow 1.13.1 with CUDA 10.0

    --- Training results are logged to TensorBoard

    --- Export the results from TensorBoard as .csv files

    --- Plot the graphs with MATLAB or Origin

    (Under Python 3.6 Environment) $ python MI_Proposed_CNNs_Architecture.py
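
The version requirements in the steps above are easy to mix up, since step 2 needs Python 2 and step 4 runs under Python 3.6. A hedged, standard-library-only sketch that checks an interpreter and TensorFlow version against the pairs the README states. `check_step`, its step labels, and `README_TF_CUDA` are hypothetical names; the installed CUDA toolkit itself is not detected by this check.

```python
import sys

# Version pairs stated in the README: TensorFlow version -> CUDA toolkit version.
README_TF_CUDA = {"1.12.0": "9.0", "1.13.1": "10.0"}


def check_step(step, python_version, tf_version=None):
    """List mismatches between an environment and the README's stated requirements.

    step: "read_edf" for Read_Raw_Data_Save_Into_Matlab_Files.py or
          "train_cnn" for MI_Proposed_CNNs_Architecture.py (hypothetical labels).
    python_version: (major, minor) tuple, e.g. tuple(sys.version_info[:2]).
    tf_version: installed TensorFlow version string, only used for "train_cnn".
    """
    problems = []
    if step == "read_edf":
        if python_version[0] != 2:
            problems.append("needs Python 2 (2.7 recommended), got %d.%d" % python_version)
    elif step == "train_cnn":
        if python_version != (3, 6):
            problems.append("README runs it under Python 3.6, got %d.%d" % python_version)
        if tf_version not in README_TF_CUDA:
            problems.append("TensorFlow %s is not one of %s"
                            % (tf_version, sorted(README_TF_CUDA)))
    else:
        raise ValueError("unknown step: %r" % (step,))
    return problems


if __name__ == "__main__":
    # Check the interpreter that is running this file against step 2's requirement.
    print(check_step("read_edf", tuple(sys.version_info[:2])))
```

When `check_step("train_cnn", ...)` returns no problems, `README_TF_CUDA[tf_version]` gives the CUDA toolkit version the README pairs with that TensorFlow release. The script uses `%` formatting so that it also runs under Python 2.7.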
    

Structure of the code

At the root of the project, you will see:

├── PhysioNet_MI_Dataset
│   └── MIND_Get_EDF.py
├── Read_Raw_Data_Save_Into_Matlab_Files.py
├── Saved_Matlab_Data
│   └── Preprocessing_Raw_Data.m
├── MI_Proposed_CNNs_Architecture.py
└── electrode_positions.txt

Citation

If you find our work useful in your research, please consider citing it in your publications. We provide a BibTeX entry below.

@article{hou2020novel,
	title     = {A Novel Approach of Decoding EEG Four-class Motor Imagery Tasks via Scout ESI and CNN},
	author    = {Hou, Yimin and Zhou, Lu and Jia, Shuyue and Lun, Xiangmin},
	journal   = {Journal of Neural Engineering},
	volume    = {17},
	number    = {1},
	pages     = {016048},
	year      = {2020},
	month     = feb,
	publisher = {IOP Publishing}
}

Acknowledgment

We are very grateful to Prof. Yimin Hou for his kind guidance; this paper would not have been possible without him.

eeg-motor-imagery-classification-cnns-tensorflow's Issues

Data processing before entering CNN

Hello, I would like to ask: do you split each 23-by-640 time-frequency map obtained from the joint time-frequency analysis into 23 reshaped 32-by-20 matrices and feed those into the CNN?
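
The reshaping this question describes can be written as a short NumPy sketch. Whether the authors used this layout, and in which element order, is not stated here; the array contents are placeholder values and the trailing channel axis is an assumption about how a 2-D CNN would consume the result.

```python
import numpy as np

# Placeholder values standing in for one 23 x 640 time-frequency map.
tf_map = np.arange(23 * 640, dtype=float).reshape(23, 640)

# Fold each 640-value row into a 32 x 20 matrix using NumPy's default row-major (C) order.
images = tf_map.reshape(23, 32, 20)

# Add a trailing channel axis: the (height, width, channels) layout TensorFlow uses by default.
cnn_input = images[..., np.newaxis]

# Each 32 x 20 matrix holds exactly the values of the corresponding original row.
assert np.array_equal(images[5].ravel(), tf_map[5])
print(images.shape, cnn_input.shape)  # (23, 32, 20) (23, 32, 20, 1)
```

MATLAB's `reshape` fills in column-major order, so a layout produced in MATLAB would correspond to `reshape(..., order="F")` in NumPy rather than the default.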

need the lead field matrix and scout indices

Hi

I am reading the paper now.

Could you provide the lead field matrix and scout indices used in the paper? That would be helpful. I think they could be exported from Brainstorm.

Thank you.

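Possible bug: test-set data appears to be used during training

In the data preprocessing, I see that your method extracts every channel separately. Consider the following case:
training set T1, test set T2, and three channels C1, C2, C3.
Clearly, the data of all three channels of T2 should be unknown during training.
With your method, however, the channels are extracted during preprocessing and then combined into one matrix, so the smallest data unit is not an event but one channel of one event.
As the amount of data grows, it becomes practically certain that channels of test events are used in training.
In your MATLAB code, the processed data for the 105 subjects has shape 149940x640, where 640 = 160 Hz * 4 s and 149940 = 105 subjects * 17 channels * 84 events.
Shuffling these rows and then splitting them into training and test sets means that channel information from the test events is used during training.
This is only the guess of a fourth-year undergraduate student; I hope you can clear up my confusion.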
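
The concern above can be checked with a small NumPy simulation that contrasts shuffling individual channel rows with splitting by whole trials. The row layout (one row per subject, channel, and event), the 20% test fraction, and the helper names are assumptions for illustration; this does not reproduce the repository's MATLAB code.

```python
import numpy as np

# Sizes quoted in the issue: 105 subjects x 17 channels x 84 events = 149940 rows.
N_SUBJECTS, N_CHANNELS, N_EVENTS = 105, 17, 84
TEST_FRAC = 0.2  # hypothetical test fraction

# Assumed row layout: one row per (subject, channel, event).
subj, _chan, evt = np.meshgrid(np.arange(N_SUBJECTS), np.arange(N_CHANNELS),
                               np.arange(N_EVENTS), indexing="ij")
event_id = (subj * N_EVENTS + evt).ravel()  # same id for all 17 channels of one trial


def row_level_split(n_rows, test_frac, rng):
    """Shuffle individual rows (channels of events), as described in the issue."""
    perm = rng.permutation(n_rows)
    n_test = int(n_rows * test_frac)
    return perm[n_test:], perm[:n_test]


def event_level_split(event_id, test_frac, rng):
    """Assign whole trials (all channels of an event) to either train or test."""
    events = np.unique(event_id)
    test_events = rng.choice(events, size=int(len(events) * test_frac), replace=False)
    is_test = np.isin(event_id, test_events)
    return np.flatnonzero(~is_test), np.flatnonzero(is_test)


def shared_events(train_idx, test_idx, event_id):
    """Number of trials that have channels in both the training and the test set."""
    return np.intersect1d(event_id[train_idx], event_id[test_idx]).size


rng = np.random.default_rng(0)
tr, te = row_level_split(event_id.size, TEST_FRAC, rng)
print("row-level split, shared trials:", shared_events(tr, te, event_id))
tr, te = event_level_split(event_id, TEST_FRAC, rng)
print("event-level split, shared trials:", shared_events(tr, te, event_id))
```

With the row-level split, nearly every trial contributes channels to both sets, while the event-level split shares none by construction. Splitting by subject instead of by event would be a stricter variant of the same idea.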

Error when running python MI_Proposed_CNNs_Architecture.py

Traceback (most recent call last):
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1334, in _do_call
return fn(*args)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1319, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1407, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InternalError: Blas GEMM launch failed : a.shape=(128, 4), b.shape=(512, 4), m=128, n=512, k=4
[[{{node Train_Optimizer/gradients/Output_Layer/prediction/MatMul_grad/MatMul}} = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](Train_Optimizer/gradients/Output_Layer/prediction/add_grad/tuple/control_dependency, Output_Layer/W_fc2/Variable/read)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "MI_Proposed_CNNs_Architecture.py", line 582, in <module>
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.50})
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 929, in run
run_metadata_ptr)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1152, in _run
feed_dict_tensor, options, run_metadata)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1328, in _do_run
run_metadata)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1348, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Blas GEMM launch failed : a.shape=(128, 4), b.shape=(512, 4), m=128, n=512, k=4
[[node Train_Optimizer/gradients/Output_Layer/prediction/MatMul_grad/MatMul (defined at MI_Proposed_CNNs_Architecture.py:301) = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](Train_Optimizer/gradients/Output_Layer/prediction/add_grad/tuple/control_dependency, Output_Layer/W_fc2/Variable/read)]]

Caused by op 'Train_Optimizer/gradients/Output_Layer/prediction/MatMul_grad/MatMul', defined at:
File "MI_Proposed_CNNs_Architecture.py", line 301, in <module>
train_step = tf.train.AdamOptimizer(1e-5).minimize(loss)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 400, in minimize
grad_loss=grad_loss)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 519, in compute_gradients
colocate_gradients_with_ops=colocate_gradients_with_ops)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 630, in gradients
gate_gradients, aggregation_method, stop_gradients)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 814, in _GradientsHelper
lambda: grad_fn(op, *out_grads))
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 408, in _MaybeCompile
return grad_fn() # Exit early
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 814, in <lambda>
lambda: grad_fn(op, *out_grads))
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py", line 1130, in _MatMulGrad
grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py", line 4560, in mat_mul
name=name)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
return func(*args, **kwargs)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
op_def=op_def)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
self._traceback = tf_stack.extract_stack()

...which was originally created as op 'Output_Layer/prediction/MatMul', defined at:
File "MI_Proposed_CNNs_Architecture.py", line 290, in <module>
prediction = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py", line 2057, in matmul
a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py", line 4560, in mat_mul
name=name)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
return func(*args, **kwargs)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
op_def=op_def)
File "/root/miniconda3/envs/oldMotor/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
self._traceback = tf_stack.extract_stack()

InternalError (see above for traceback): Blas GEMM launch failed : a.shape=(128, 4), b.shape=(512, 4), m=128, n=512, k=4
[[node Train_Optimizer/gradients/Output_Layer/prediction/MatMul_grad/MatMul (defined at MI_Proposed_CNNs_Architecture.py:301) = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](Train_Optimizer/gradients/Output_Layer/prediction/add_grad/tuple/control_dependency, Output_Layer/W_fc2/Variable/read)]]

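Hello, I am a third-year undergraduate student and have recently been reproducing your paper for inspiration. However, when I run (Under Python 3.6 Environment) $ python MI_Proposed_CNNs_Architecture.py, I get the error above. I have searched extensively without finding a solution. The most likely cause seems to be a mismatch between the TensorFlow and CUDA versions, but I am not sure whether that is correct.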

My machine configuration is as follows:
NVIDIA-SMI 535.146.02 Driver Version: 535.146.02 CUDA Version: 12.2
NVIDIA GeForce RTX 4090, 24 GB VRAM

The conda environment is as follows (running Python 3.6.13):

absl-py 0.15.0
astor 0.8.1
certifi 2021.5.30
coverage 5.5
Cython 0.29.24
dataclasses 0.8
et-xmlfile 1.1.0
gast 0.5.3
grpcio 1.36.1
h5py 2.10.0
importlib-metadata 4.8.1
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.2
Markdown 3.3.4
mkl-fft 1.3.0
mkl-random 1.1.1
mkl-service 2.3.0
numpy 1.19.2
openpyxl 3.1.2
pandas 1.1.5
pip 20.0.2
protobuf 3.17.2
python-dateutil 2.9.0.post0
pytz 2024.1
scipy 1.5.2
setuptools 36.4.0
six 1.16.0
tensorboard 1.12.2
tensorflow 1.12.0
termcolor 1.1.0
typing-extensions 4.1.1
Werkzeug 2.0.3
wheel 0.37.1
xlrd 1.2.0
zipp 3.6.0

original data question

Hello! May I ask a question about the original data? Why is each recording split into two files (.edf and .edf.event)? When I load them into MATLAB, the data types are wrong. Why?

IndexError: list index out of range

Dear @SuperBruceJia, I have tried both your suggestion and wei's suggestion, but the error still persists. I have included the error below; please help me.
Start to save the File!
Traceback (most recent call last):
File "Read_Raw_Data_Save_Into_Matlab_Files.py", line 209, in <module>
X_105_C5, y_105_C5 = load_raw_data(electrodes=electrodes, subject=subject, num_classes=nclasses)
File "Read_Raw_Data_Save_Into_Matlab_Files.py", line 200, in load_raw_data
return np.array(trials, dtype=np.float64).reshape((len(trials),) + trials[0].shape + (1,)), np.array(labels, dtype=np.float64)
IndexError: list index out of range
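
Line 200 of the traceback indexes `trials[0]`, and "list index out of range" on that expression means `trials` was an empty list, i.e. no trials were read before the reshape; one likely reason is that no .edf files (or no matching events) were found. A hedged sketch of a guard that turns the empty case into an explicit message. `stack_trials` is a hypothetical helper, not part of the repository; it keeps Python 2.7 compatible syntax because this script runs under Python 2.

```python
import numpy as np


def stack_trials(trials, labels):
    """Stack per-trial arrays into shape (n_trials, ...) + (1,), failing clearly if none were loaded."""
    if len(trials) == 0:
        raise RuntimeError(
            "No trials were loaded: check that the .edf files exist at the expected path "
            "and that the subject and electrode selection matches them.")
    # Same expression as line 200 of the traceback, now only reached with a non-empty list.
    X = np.array(trials, dtype=np.float64).reshape((len(trials),) + trials[0].shape + (1,))
    y = np.array(labels, dtype=np.float64)
    return X, y


if __name__ == "__main__":
    X, y = stack_trials([np.zeros((2, 640)), np.ones((2, 640))], [0, 1])
    print(X.shape)  # (2, 2, 640, 1)
```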

ValueError: need at least one array to stack

hi
I tried to run your code, but it gives me this error: ValueError: need at least one array to stack. I am sure the path to the dataset is correct. Can you help me?
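
NumPy raises `ValueError: need at least one array to stack` when `np.stack` receives an empty sequence, so this error most likely means the loader found no data to stack, even if the path looks correct. A hedged pre-check that fails early and prints the resolved absolute path, which often exposes a relative-path or working-directory mistake. `find_edf_files` is a hypothetical helper, not part of the repository; it skips the `.edf.event` annotation files.

```python
import os


def find_edf_files(data_dir):
    """Return sorted paths of all .edf files below data_dir; raise if there are none."""
    if not os.path.isdir(data_dir):
        raise ValueError("Not a directory: %r" % os.path.abspath(data_dir))
    files = []
    for root, _dirs, names in os.walk(data_dir):
        # Only names ending in ".edf"; "*.edf.event" files end in ".event" and are skipped.
        files.extend(os.path.join(root, n) for n in names if n.endswith(".edf"))
    if not files:
        raise ValueError("No .edf files found under %r" % os.path.abspath(data_dir))
    return sorted(files)
```

Calling it with the same directory the loader uses, before any stacking, shows immediately whether the data is where the script expects it.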

ERROR: Read_Raw_Data_Save_Into_Matlab_Files.py

Hello, I'm running the code "Read_Raw_Data_Save_Into_Matlab_Files.py" under the environment with python version= 2.7.18, but I always get an error shows below, is there anything wrong? Thanks!

Traceback (most recent call last):
File "Read_Raw_Data_Save_Into_Matlab_Files.py", line 209, in <module>
X_105_C5, y_105_C5 = load_raw_data(electrodes=electrodes, subject=subject, num_classes=nclasses)
File "Read_Raw_Data_Save_Into_Matlab_Files.py", line 200, in load_raw_data
return np.array(trials, dtype=np.float64).reshape((len(trials),) + trials[0].shape + (1,)), np.array(labels, dtype=np.float64)
IndexError: list index out of range

Incompatible shapes: [64,4] vs. [10,4]

Hello Bruce
I am trying to run the code MI_Proposed_CNNs_Architecture.py and I am having a problem at this line

sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: keep_rate})
It gives me this error

Exception has occurred: InvalidArgumentError
Incompatible shapes: [64,4] vs. [10,4]
	 [[node sub_5 (defined at i:\IBM\CNN\Loss_Function\Loss.py:32) ]]

Errors may have originated from an input operation.
Input Source operations connected to node sub_5:
 Placeholder_1 (defined at i:\VBOX Shared Folder\IBM TUH Repo\CNN\main-CNN.py:58)	
 Softmax (defined at i:\IBM\CNN\Network\CNN.py:89)

Original stack trace for 'sub_5':
  File "C:\Users\******\Anaconda3\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "C:\Users\******\Anaconda3\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\Users\******\.vscode\extensions\ms-python.python-2021.5.842923320\pythonFiles\lib\python\debugpy\__main__.py", line 45, in <module>
    cli.main()
  File "c:\Users\******\.vscode\extensions\ms-python.python-2021.5.842923320\pythonFiles\lib\python\debugpy/..\debugpy\server\cli.py", line 444, in main
    run()
  File "c:\Users\******\.vscode\extensions\ms-python.python-2021.5.842923320\pythonFiles\lib\python\debugpy/..\debugpy\server\cli.py", line 285, in run_file
    runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
  File "C:\Users\******\Anaconda3\lib\runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "C:\Users\******\Anaconda3\lib\runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "C:\Users\******\Anaconda3\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "i:\IBM\CNN\main-CNN.py", line 65, in <module>
    loss = loss(y=y, prediction=prediction, l2_norm=True)
  File "i:\IBM\CNN\Loss_Function\Loss.py", line 32, in loss
    model_loss = tf.reduce_mean(tf.square(y - prediction))
  File "C:\Users\******\Anaconda3\lib\site-packages\tensorflow\python\ops\math_ops.py", line 1164, in binary_op_wrapper
    return func(x, y, name=name)
  File "C:\Users\******\Anaconda3\lib\site-packages\tensorflow\python\util\dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "C:\Users\******\Anaconda3\lib\site-packages\tensorflow\python\ops\math_ops.py", line 561, in subtract
    return gen_math_ops.sub(x, y, name)
  File "C:\Users\******\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 10316, in sub
    "Sub", x=x, y=y, name=name)
  File "C:\Users\******\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 750, in _apply_op_helper
    attrs=attr_protos, op_def=op_def)
  File "C:\Users\******\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 3536, in _create_op_internal
    op_def=op_def)
  File "C:\Users\******\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1990, in __init__
    self._traceback = tf_stack.extract_stack()

During handling of the above exception, another exception occurred:

  File "i:\IBM\CNN\main-CNN.py", line 97, in <module>
    sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: keep_rate})

batch_xs.shape : (64, 640)
batch_ys.shape : (64, 4)
Where does the tensor of shape [10,4] come from?
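
The inputs listed for `sub_5` are `Placeholder_1` (the labels `y`, fed with shape (64, 4)) and `Softmax`, so the (10, 4) operand is most likely the network output. As an inference from the trace rather than a confirmed diagnosis, that suggests a batch size of 10 is fixed somewhere inside the reporter's `CNN.py`, for example a `tf.reshape` with a literal batch size instead of `-1`. The NumPy sketch below mirrors the loss at `Loss.py` line 32 with an explicit shape check; `mse_loss` and the 32 x 20 x 1 input layout are hypothetical.

```python
import numpy as np


def mse_loss(y, prediction):
    """NumPy analogue of tf.reduce_mean(tf.square(y - prediction)) that rejects mismatched shapes."""
    y = np.asarray(y, dtype=float)
    prediction = np.asarray(prediction, dtype=float)
    if y.shape != prediction.shape:
        raise ValueError("labels %s vs. predictions %s: the batch sizes must agree"
                         % (y.shape, prediction.shape))
    return float(np.mean(np.square(y - prediction)))


batch_xs = np.zeros((64, 640))
batch_ys = np.zeros((64, 4))

# Inferring the batch dimension with -1 keeps it equal to the number of fed rows (64);
# the 32 x 20 x 1 layout is a hypothetical example, not the network's confirmed input.
images = batch_xs.reshape(-1, 32, 20, 1)
print(images.shape)  # (64, 32, 20, 1)

# A hypothetical network output whose batch size is fixed at 10 reproduces the mismatch.
try:
    mse_loss(batch_ys, np.zeros((10, 4)))
except ValueError as err:
    print(err)
```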

Thank you
