Comments (3)
you can use tf.train.MonitoredTrainingSession instead of tf.Session , and global_variables_initializer
is not necessary when using MonitoredTrainingSession, you can refer https://github.com/alibaba/FastNN/blob/73b70c633117ccff4f1a270f461bacb96e0fc4ee/resnet/resnet_dp.py#L67
from easyparallellibrary.
Thanks for your reply! I modify the code as follows:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
import epl
import os
def conv_bn_relu(inputs, filters, kernel_size, stride, training):
conv = tf.layers.conv2d(inputs, filters, kernel_size, strides=stride, padding='SAME', use_bias=False)
bn = tf.layers.batch_normalization(conv, training=training)
relu = tf.nn.relu(bn)
return relu
def bottleneck_block(inputs, filters, stride, training):
shortcut = inputs
out = conv_bn_relu(inputs, filters, 1, 1, training)
out = conv_bn_relu(out, filters, 3, stride, training)
out = conv_bn_relu(out, 4 * filters, 1, 1, training)
if stride != 1 or inputs.get_shape().as_list()[-1] != 4 * filters:
shortcut = tf.layers.conv2d(inputs, 4 * filters, 1, strides=stride, padding='SAME', use_bias=False)
shortcut = tf.layers.batch_normalization(shortcut, training=training)
out = tf.add(out, shortcut)
return out
def resnet50(inputs, training):
out = conv_bn_relu(inputs, 64, 3, 1, training)
out = bottleneck_block(out, 64, 1, training)
out = bottleneck_block(out, 128, 2, training)
out = bottleneck_block(out, 256, 2, training)
out = bottleneck_block(out, 512, 2, training)
out = tf.layers.average_pooling2d(out, 4, 1)
out = tf.layers.flatten(out)
out = tf.layers.dense(out, 10)
return out
def run_model():
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train, X_test = X_train.astype(np.float32) / 255.0, X_test.astype(np.float32) / 255.0
y_train, y_test = y_train.astype(np.int32), y_test.astype(np.int32)
images = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name='images')
labels = tf.placeholder(tf.int32, shape=(None), name='labels')
is_training = tf.placeholder(tf.bool, name='is_training')
logits = resnet50(images, is_training)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.AdamOptimizer(0.001)
train_op = optimizer.minimize(loss, global_step=global_step)
batch_size = 128
n_epochs = 100
hooks = [tf.train.StopAtStepHook(last_step=n_epochs * len(X_train) // batch_size)]
def get_batch(data, labels, batch_size):
idx = np.random.choice(np.arange(len(data)), batch_size, replace=False)
return data[idx], labels[idx].flatten()
with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
while not sess.should_stop():
batch_images, batch_labels = get_batch(X_train, y_train, batch_size)
_, train_loss, step = sess.run(
[train_op, loss, global_step],
feed_dict={images: batch_images, labels: batch_labels, is_training: True}
)
if step % 100 == 0:
print(f"Step {step}, Loss: {train_loss:.4f}")
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
config_json = {}
epl.init(epl.Config(config_json))
print(epl.Env.get().cluster.gpu_num_per_worker)
if epl.Env.get().cluster.gpu_num_per_worker > 1:
# Avoid NCCL hang.
os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
epl.set_default_strategy(epl.replicate(device_count=1))
run_model()
However, I am confronted with the following issue:
Traceback (most recent call last):
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1365, in _do_call
return fn(*args)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1350, in _run_fn
target_list, run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1443, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: From /job:worker/replica:0/task:0:
You must feed a value for placeholder tensor 'EPL_REPLICA_1/labels' with dtype int32
[[{{node EPL_REPLICA_1/labels}}]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "resnet50_split4.py", line 89, in
run_model()
File "resnet50_split4.py", line 74, in run_model
feed_dict={images: batch_images, labels: batch_labels, is_training: True}
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 754, in run
run_metadata=run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1259, in run
run_metadata=run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1360, in run
raise six.reraise(*original_exc_info)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/six.py", line 719, in reraise
raise value
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1345, in run
return self._sess.run(*args, **kwargs)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1418, in run
run_metadata=run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1176, in run
return self._sess.run(*args, **kwargs)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 464, in run
outputs = fn(self, actual_fetches, feed_dict, options, run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
run_metadata_ptr)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1180, in _run
feed_dict_tensor, options, run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1359, in _do_run
run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1384, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: From /job:worker/replica:0/task:0:
You must feed a value for placeholder tensor 'EPL_REPLICA_1/labels' with dtype int32
[[node EPL_REPLICA_1/labels (defined at /users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1748) ]]
Original stack trace for 'EPL_REPLICA_1/labels':
File "resnet50_split4.py", line 89, in
run_model()
File "resnet50_split4.py", line 69, in run_model
with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 584, in MonitoredTrainingSession
stop_grace_period_secs=stop_grace_period_secs)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1014, in init
stop_grace_period_secs=stop_grace_period_secs)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 319, in init
res = fn(self, *args, **kwargs)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 725, in init
self._sess = _RecoverableSession(self._coordinated_creator)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1207, in init
_WrappedSession.init(self, self._create_session())
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1212, in _create_session
return self._sess_creator.create_session()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 878, in create_session
self.tf_sess = self._session_creator.create_session()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 638, in create_session
self._scaffold.finalize()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 273, in finalize
fn(self)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 239, in finalize
ops.get_default_graph().finalize()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 261, in finalize
Parallel.get().do_parallelism()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/parallel.py", line 223, in do_parallelism
self.transformer.replicas_clone()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/graph_editor.py", line 427, in replicas_clone
self._forward_clone()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/graph_editor.py", line 343, in _forward_clone
target_device)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/ops.py", line 237, in node_clone_for_replicas
op_def=op_def)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 1748, in init
self._traceback = tf_stack.extract_stack()
Could you give me a hand? Thank you very much!
from easyparallellibrary.
you should replace get_batch with tf.data.Dataset
from easyparallellibrary.
Related Issues (9)
- 训练时,除chief worker外,其余worker在每次save checkpoint 后 step归0,且在第二次save checkpoint 后 整个进程卡死 HOT 1
- Gradient Checkpoint with auto type got a TypeError HOT 1
- Problem of Data Parallel Model, program didn't end when reached global step HOT 1
- DistributedDense只支持按照列切分吗? HOT 1
- DingTalk QR code is outdated HOT 2
- 2台服务器分布式跑example中的resnet_split.py遇到无限等待的情况 HOT 4
- 2机2卡实验NCCL报错 HOT 1
- epl单机单卡和单机多卡训练step如何理解 HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from easyparallellibrary.