import random
import threading
import time

import numpy as np
import tensorflow as tf

batch_size = 2
get_single_xy_timeout = 10
train_timeout = 3
n_epochs = 20
# dummy function simulating loading, preprocessing
# and other operations needed for the X and Y
def get_single_xy():
    time.sleep(get_single_xy_timeout)
    x = np.random.rand(2, 2)
    y = random.random()
    return x, y
# placeholders
X = tf.placeholder(tf.float32, [2, 2])
Y = tf.placeholder(tf.float32)
# create the queue
queue = tf.FIFOQueue(
    capacity=15,
    dtypes=[tf.float32, tf.float32],
    shapes=[[2, 2], []],
)
# Enqueues (add) one element to this queue.
enqueue_op = queue.enqueue([X, Y])
# Dequeues (remove) one element from this queue.
dequeue_op = queue.dequeue()
# Dequeues `n` elements and stacks them along a new 0th dimension,
# so `Xs` has shape [batch_size, 2, 2] and `Ys` has shape [batch_size].
Xs, Ys = queue.dequeue_many(n=batch_size)
# dummy train operation
train_op = tf.reduce_mean(Xs) * Ys
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)
# A coordinator for the enqueue threads: any thread can call
# coord.request_stop(), and every thread polls coord.should_stop().
coord = tf.train.Coordinator()
def enqueue_thread():
    # Context manager to request stop when an Exception is raised.
    with coord.stop_on_exception():
        while not coord.should_stop():
            x, y = get_single_xy()
            sess.run(enqueue_op, feed_dict={X: x, Y: y})
available_threads = 5
for _ in range(available_threads):
    threading.Thread(target=enqueue_thread).start()
for epoch in range(n_epochs):
    start_time = time.time()
    sess.run(train_op)
    time.sleep(train_timeout)
    print('Time:', time.time() - start_time)
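# Note: as written above, the enqueue threads never exit; once the training
# loop finishes they stay blocked inside `sess.run(enqueue_op)` waiting for
# free queue capacity. A minimal shutdown sketch, reusing the `coord`,
# `queue` and `sess` defined above: closing the queue with
# cancel_pending_enqueues=True cancels the blocked enqueues, the resulting
# error is swallowed by `coord.stop_on_exception()`, and the threads exit.
coord.request_stop()
sess.run(queue.close(cancel_pending_enqueues=True))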
# https://github.com/tensorflow/tensorflow/issues/2514#issuecomment-221934925
import time

import numpy as np
import tensorflow as tf
from threading import Thread
batch_size = 2
get_single_xy_timeout = 10
train_timeout = 3
n_steps = 5
n_threads = 3
q_capacity = 3
# dummy class simulating loading, preprocessing
# and other operations needed for the X and Y
class Dataset():
    def __init__(self):
        self.counter = 0
        self.x_dtype = tf.float32
        self.x_shape = [2, 2]
        self.y_dtype = tf.float32
        self.y_shape = []
        self.dtypes = [self.x_dtype, self.y_dtype]
        self.shapes = [self.x_shape, self.y_shape]

    def get_single_xy(self, train=True):
        # can have a state
        self.counter += 1
        time.sleep(get_single_xy_timeout)
        # types of x, y have to match the queue
        x = np.random.rand(2, 2)
        if train:
            y = np.random.rand()
        else:
            y = 100000.
        return [x, y]
dataset = Dataset()
q_train = tf.FIFOQueue(
    capacity=q_capacity,
    dtypes=dataset.dtypes,
    shapes=dataset.shapes,
)
X_train = tf.placeholder(dataset.x_dtype, dataset.x_shape)
Y_train = tf.placeholder(dataset.y_dtype, dataset.y_shape)
enqueue_train = q_train.enqueue([X_train, Y_train])
q_test = tf.FIFOQueue(
    capacity=q_capacity,
    dtypes=dataset.dtypes,
    shapes=dataset.shapes,
)
X_test = tf.placeholder(dataset.x_dtype, dataset.x_shape)
Y_test = tf.placeholder(dataset.y_dtype, dataset.y_shape)
enqueue_test = q_test.enqueue([X_test, Y_test])
# an int32 placeholder selects which queue to read from:
# tf.QueueBase.from_list returns a queue whose operations are forwarded
# to queues[q_selector], here q_train (0) or q_test (1)
q_selector = tf.placeholder(tf.int32, [])
q = tf.QueueBase.from_list(q_selector, [q_train, q_test])
dequeue_op = q.dequeue()
X, Y = q.dequeue_many(n=batch_size)
# dummy train operation
train_op = tf.reduce_mean(tf.reduce_mean(X) * Y)
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)
# Create a coordinator for the manually launched pusher threads.
coordinator = tf.train.Coordinator()
def train_pusher():
    with coordinator.stop_on_exception():
        while not coordinator.should_stop():
            x, y = dataset.get_single_xy(train=True)
            sess.run(enqueue_train, {X_train: x, Y_train: y})

def test_pusher():
    with coordinator.stop_on_exception():
        while not coordinator.should_stop():
            x, y = dataset.get_single_xy(train=False)
            sess.run(enqueue_test, {X_test: x, Y_test: y})
threads = [Thread(target=train_pusher) for i in range(n_threads)]
threads += [Thread(target=test_pusher) for i in range(n_threads)]
for t in threads:
    t.start()
try:
    for step in range(n_steps):
        # if something is not right, stop
        if coordinator.should_stop():
            break
        start_time = time.time()
        result = sess.run(train_op, {q_selector: 0})
        time.sleep(train_timeout)
        print('Result', result, 'Time:', time.time() - start_time)
    for step in range(n_steps):
        start_time = time.time()
        result = sess.run(train_op, {q_selector: 1})
        time.sleep(train_timeout)
        print('Result', result, 'Time:', time.time() - start_time)
except Exception as e:
    # Report exceptions to the coordinator.
    coordinator.request_stop(e)
finally:
    # Terminate as usual. It is innocuous to request stop twice.
    coordinator.request_stop()
    # join(0) is non-blocking: the pusher threads may still be stuck
    # inside `sess.run(enqueue_*)` and would never finish a plain join().
    for t in threads:
        t.join(0)
    sess.close()
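# Optional: FIFOQueue exposes a `size()` op, handy for checking whether the
# pusher threads keep up with training. A small sketch; the `print` would
# have to run inside the training loop above, before `sess.close()`:
train_q_size = q_train.size()
test_q_size = q_test.size()
# print('queued train/test:', sess.run(train_q_size), sess.run(test_q_size))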
import time

import numpy as np
import tensorflow as tf
batch_size = 2
get_single_xy_timeout = 10
train_timeout = 3
n_epochs = 5
n_threads = 5
# dummy class simulating loading, preprocessing
# and other operations needed for the X and Y
class DataFetcher():
    def __init__(self):
        self.counter = 0

    def get_single_xy(self):
        # can have a state
        self.counter += 1
        print(self.counter)
        time.sleep(get_single_xy_timeout)
        # types of x, y have to match the queue
        x = np.random.rand(2, 2).astype(np.float32)
        y = np.random.rand(1).astype(np.float32)[0]
        return [x, y]
data = DataFetcher()
# create the queue
queue = tf.FIFOQueue(
    capacity=15,
    dtypes=[tf.float32, tf.float32],
    shapes=[[2, 2], []],
)
# wrap the Python fetcher in a TensorFlow op; every evaluation calls
# data.get_single_xy() once and returns its outputs as tensors
python_data_op = tf.py_func(data.get_single_xy, inp=[], Tout=[tf.float32, tf.float32])
# Enqueues (add) one element to this queue.
enqueue_op = queue.enqueue(python_data_op)
# Dequeues (remove) one element from this queue.
dequeue_op = queue.dequeue()
# Dequeues and concatenates `n` elements from this queue.
X, Y = queue.dequeue_many(n=batch_size)
# dummy train operation
train_op = tf.reduce_mean(tf.reduce_mean(X) * Y)
# Create a queue runner that will run `n_threads` threads in parallel
# to enqueue examples.
qr = tf.train.QueueRunner(queue, [enqueue_op] * n_threads)
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)
# Create a coordinator, launch the queue runner threads.
coord = tf.train.Coordinator()
threads = qr.create_threads(sess, coord=coord, start=True)
try:
    for step in range(n_epochs):
        if coord.should_stop():
            break
        # inside the train loop
        start_time = time.time()
        result = sess.run(train_op)
        time.sleep(train_timeout)
        print('Result', result, 'Time:', time.time() - start_time)
except Exception as e:
    # Report exceptions to the coordinator.
    coord.request_stop(e)
finally:
    # Terminate as usual. It is innocuous to request stop twice.
    coord.request_stop()
    coord.join(threads)
    sess.close()
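# A common variant, sketched under the same graph as above: instead of calling
# qr.create_threads() yourself, register the runner in the graph's
# QUEUE_RUNNERS collection and let tf.train.start_queue_runners() launch every
# registered runner. Both calls are standard TF 1.x API.
tf.train.add_queue_runner(qr)
with tf.Session() as sess2:
    sess2.run(init_op)
    coord2 = tf.train.Coordinator()
    threads2 = tf.train.start_queue_runners(sess=sess2, coord=coord2)
    print('Result', sess2.run(train_op))
    coord2.request_stop()
    coord2.join(threads2)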