post2web
2/9/2017 - 1:37 AM

TensorFlow threads with QueueRunner

# Demo tunables. The two *_timeout values are handed to time.sleep().
n_epochs = 20               # iterations of the training loop
batch_size = 2              # elements requested per dequeue_many call
train_timeout = 3           # sleep after each training step
get_single_xy_timeout = 10  # sleep inside get_single_xy

# Stand-in for the real loading / preprocessing work that would
# normally be needed to produce one (X, Y) sample.
def get_single_xy():
    """Return one random sample after sleeping get_single_xy_timeout.

    Returns:
        A tuple (x, y): x is a 2x2 array from np.random.rand and y is a
        float from random.random().
    """
    # Pretend the sample is expensive to produce.
    time.sleep(get_single_xy_timeout)
    features = np.random.rand(2, 2)
    target = random.random()
    return features, target

# Feed points for a single sample: a 2x2 float32 matrix and a float32 value.
X = tf.placeholder(dtype=tf.float32, shape=[2, 2])
Y = tf.placeholder(dtype=tf.float32)

# FIFO queue holding at most 15 (2x2 matrix, scalar) pairs.
queue = tf.FIFOQueue(15, [tf.float32, tf.float32], shapes=[[2, 2], []])

# Op that adds the fed (X, Y) pair to the queue as one element.
enqueue_op = queue.enqueue([X, Y])
# Op that removes one element from the queue.
dequeue_op = queue.dequeue()
# Op that removes `batch_size` elements and returns them concatenated.
Xs, Ys = queue.dequeue_many(batch_size)

# Placeholder for a real training step: mean of the batch inputs times Ys.
train_op = tf.multiply(tf.reduce_mean(Xs), Ys)

sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

# Coordinator used by the producer threads to agree on when to stop.
coord = tf.train.Coordinator()

def enqueue_thread():
    """Producer loop: build samples and enqueue them until asked to stop."""
    # Any exception raised inside the block makes the coordinator request a stop.
    with coord.stop_on_exception():
        while True:
            if coord.should_stop():
                break
            sample_x, sample_y = get_single_xy()
            sess.run(enqueue_op, feed_dict={X: sample_x, Y: sample_y})

available_threads = 5
# Keep references to the producers so they can be joined at shutdown.
enqueue_threads = [
    threading.Thread(target=enqueue_thread) for _ in range(available_threads)
]
for producer in enqueue_threads:
    producer.start()

try:
    for epoch in range(n_epochs):
        # A producer that failed has already asked the coordinator to stop.
        if coord.should_stop():
            break
        start_time = time.time()
        sess.run(train_op)
        time.sleep(train_timeout)
        print('Time:', time.time() - start_time)
finally:
    # Without this the producers loop forever and, being non-daemon
    # threads, keep the interpreter alive after training has finished.
    coord.request_stop()
    try:
        # Cancel enqueues that are blocked on a full queue so that every
        # producer gets back to its should_stop() check (or fails fast).
        sess.run(queue.close(cancel_pending_enqueues=True))
        # Wait for the producers; re-raises an exception reported by one.
        coord.join(enqueue_threads)
    finally:
        sess.close()
# https://github.com/tensorflow/tensorflow/issues/2514#issuecomment-221934925
import time
from threading import Thread

import numpy as np
import tensorflow as tf

# Demo tunables. The two *_timeout values are handed to time.sleep().
n_steps = 5                 # iterations of each of the two run loops
n_threads = 3               # producer threads started per queue
q_capacity = 3              # capacity of each FIFO queue
batch_size = 2              # elements requested per dequeue_many call
train_timeout = 3           # sleep after each step of the run loops
get_single_xy_timeout = 10  # sleep inside Dataset.get_single_xy


# dummy class simulating loading, preprocessing
# and other operations needed for the X and Y
class Dataset():
    """Stateful sample source shared by the train and test pusher threads.

    Attributes (all set in __init__):
        counter: number of get_single_xy calls made so far.
        x_dtype, x_shape: dtype and shape of the X component.
        y_dtype, y_shape: dtype and shape of the Y component.
        dtypes, shapes: the two above as [x, y] lists, in the form the
            queues in this file take for their dtypes/shapes arguments.
    """

    def __init__(self):
        # Local import: the file only imports `Thread` from threading.
        import threading

        self.counter = 0
        # get_single_xy is called from several pusher threads at once and
        # `counter += 1` is a read-modify-write, so it needs a lock.
        self._counter_lock = threading.Lock()
        self.x_dtype = tf.float32
        self.x_shape = [2, 2]
        self.y_dtype = tf.float32
        self.y_shape = []
        self.dtypes = [self.x_dtype, self.y_dtype]
        self.shapes = [self.x_shape, self.y_shape]

    def get_single_xy(self, train=True):
        """Return one [x, y] sample after sleeping get_single_xy_timeout.

        Args:
            train: when true, y is a random float; otherwise y is the
                constant 100000. so test samples are easy to recognise.

        Returns:
            A list [x, y] where x is a random array of shape x_shape.
        """
        # can have a state
        with self._counter_lock:
            self.counter += 1
        time.sleep(get_single_xy_timeout)
        # Shape is taken from x_shape so the sample always matches the
        # shape that was declared for the queues.
        x = np.random.rand(*self.x_shape)
        y = np.random.rand() if train else 100000.
        return [x, y]

# One Dataset instance; both pusher functions pull their samples from it.
dataset = Dataset()

# Training side: feed placeholders -> bounded FIFO queue.
X_train = tf.placeholder(dtype=dataset.x_dtype, shape=dataset.x_shape)
Y_train = tf.placeholder(dtype=dataset.y_dtype, shape=dataset.y_shape)
q_train = tf.FIFOQueue(q_capacity, dataset.dtypes, shapes=dataset.shapes)
enqueue_train = q_train.enqueue([X_train, Y_train])

# Test side: same layout, with its own placeholders and queue.
X_test = tf.placeholder(dtype=dataset.x_dtype, shape=dataset.x_shape)
Y_test = tf.placeholder(dtype=dataset.y_dtype, shape=dataset.y_shape)
q_test = tf.FIFOQueue(q_capacity, dataset.dtypes, shapes=dataset.shapes)
enqueue_test = q_test.enqueue([X_test, Y_test])


# Scalar int32 fed at run time; it indexes into [q_train, q_test].
q_selector = tf.placeholder(dtype=tf.int32, shape=[])
q = tf.QueueBase.from_list(index=q_selector, queues=[q_train, q_test])
# Single-element and batched reads from whichever queue is selected.
dequeue_op = q.dequeue()
X, Y = q.dequeue_many(batch_size)


# Placeholder for a real training step: reduces the batch to one scalar.
train_op = tf.reduce_mean(tf.multiply(tf.reduce_mean(X), Y))



sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)
# Coordinator that the pusher threads and the main loop use to signal a stop.
coordinator = tf.train.Coordinator()

def train_pusher():
    """Producer loop: enqueue training samples into q_train until stopped."""
    # Any exception raised inside the block makes the coordinator request a stop.
    with coordinator.stop_on_exception():
        while True:
            if coordinator.should_stop():
                break
            sample_x, sample_y = dataset.get_single_xy(train=True)
            feed = {X_train: sample_x, Y_train: sample_y}
            sess.run(enqueue_train, feed)

def test_pusher():
    """Producer loop: enqueue test samples into q_test until stopped."""
    # Any exception raised inside the block makes the coordinator request a stop.
    with coordinator.stop_on_exception():
        while True:
            if coordinator.should_stop():
                break
            sample_x, sample_y = dataset.get_single_xy(train=False)
            feed = {X_test: sample_x, Y_test: sample_y}
            sess.run(enqueue_test, feed)

# n_threads producers for each of the two queues.
threads = [Thread(target=train_pusher) for _ in range(n_threads)]
threads += [Thread(target=test_pusher) for _ in range(n_threads)]
for pusher in threads:
    pusher.start()


def _run_steps(queue_index):
    """Run up to n_steps of train_op on batches from one of the queues.

    Args:
        queue_index: value fed to q_selector; 0 reads from q_train and
            1 reads from q_test.

    Returns early as soon as the coordinator has been asked to stop.
    """
    for _ in range(n_steps):
        # if something is not right, stop
        if coordinator.should_stop():
            return
        start_time = time.time()
        result = sess.run(train_op, {q_selector: queue_index})
        time.sleep(train_timeout)
        print('Result', result, 'Time:', time.time() - start_time)


try:
    _run_steps(0)
    _run_steps(1)
except Exception as e:
    # Report exceptions to the coordinator; join() below re-raises them.
    coordinator.request_stop(e)
finally:
    # Terminate as usual.  It is innocuous to request stop twice.
    coordinator.request_stop()
    try:
        # Cancel enqueues blocked on a full queue so that every pusher
        # gets back to its should_stop() check (or fails fast).
        sess.run([
            q_train.close(cancel_pending_enqueues=True),
            q_test.close(cancel_pending_enqueues=True),
        ])
        # Actually wait for the pushers (join(0) returned immediately) and
        # surface any exception that was reported to the coordinator.
        coordinator.join(threads)
    finally:
        sess.close()
import tensorflow as tf

# Demo tunables. The two *_timeout values are handed to time.sleep().
n_epochs = 5                # iterations of the training loop
n_threads = 5               # enqueue ops handed to the QueueRunner
batch_size = 2              # elements requested per dequeue_many call
train_timeout = 3           # sleep after each training step
get_single_xy_timeout = 10  # sleep inside DataFetcher.get_single_xy

# dummy class simulating loading, preprocessing
# and other operations needed for the X and Y
class DataFetcher():
    """Stateful sample source invoked through tf.py_func.

    Attributes:
        counter: number of get_single_xy calls made so far.
    """

    def __init__(self):
        # Local import: this script only imports tensorflow.
        import threading

        self.counter = 0
        # get_single_xy runs on several queue-runner threads at once and
        # `counter += 1` is a read-modify-write, so it needs a lock.
        self._counter_lock = threading.Lock()

    def get_single_xy(self):
        """Return one [x, y] sample after sleeping get_single_xy_timeout.

        Prints the call number, then returns a list [x, y] where x is a
        float32 array of shape (2, 2) and y is a float32 scalar.
        """
        # can have a state
        with self._counter_lock:
            self.counter += 1
            # Snapshot under the lock so each call prints its own number.
            call_number = self.counter
        print(call_number)
        time.sleep(get_single_xy_timeout)
        # types of x, y have to match queue
        x = np.random.rand(2, 2).astype(np.float32)
        y = np.random.rand(1).astype(np.float32)[0]
        return [x, y]

data = DataFetcher()

# FIFO queue holding at most 15 (2x2 float32 matrix, float32 scalar) pairs.
queue = tf.FIFOQueue(15, [tf.float32, tf.float32], shapes=[[2, 2], []])

# Graph op that calls the Python fetcher and yields its two float32 outputs.
python_data_op = tf.py_func(
    func=data.get_single_xy,
    inp=[],
    Tout=[tf.float32, tf.float32],
)

# Op that adds one fetched (x, y) pair to the queue.
enqueue_op = queue.enqueue(python_data_op)
# Op that removes one element from the queue.
dequeue_op = queue.dequeue()
# Op that removes `batch_size` elements and returns them concatenated.
X, Y = queue.dequeue_many(batch_size)

# Placeholder for a real training step: reduces the batch to one scalar.
train_op = tf.reduce_mean(tf.multiply(tf.reduce_mean(X), Y))

# Queue runner given `n_threads` copies of the enqueue op, so that many
# enqueue threads fill the queue in parallel.
qr = tf.train.QueueRunner(queue, enqueue_ops=[enqueue_op] * n_threads)

sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

# Coordinator for shutdown; the queue runner threads are started right away.
coord = tf.train.Coordinator()
threads = qr.create_threads(sess, coord=coord, start=True)

try:
    for step in range(n_epochs):
        # Stop early if a queue-runner thread has asked for a stop.
        if coord.should_stop():
            break
        # inside the train loop
        start_time = time.time()
        result = sess.run(train_op)
        time.sleep(train_timeout)
        print('Result', result, 'Time:', time.time() - start_time)

except Exception as e:
    # Report exceptions to the coordinator.
    coord.request_stop(e)
finally:
    # Terminate as usual.  It is innocuous to request stop twice.
    coord.request_stop()
    try:
        # Re-raises an exception that was reported to the coordinator.
        coord.join(threads)
    finally:
        # Close the session even when join() re-raises; previously the
        # trailing sess.close() was skipped in that case.
        sess.close()