tf.concat Function
This post shows how to use tf.concat, which concatenates a list of tensors along one dimension that you specify:
tf.concat(values, axis), where values is the list of tensors and axis is the dimension to concatenate along.
If each input tensor values[i] has the shape [D0, D1, …, Daxis(i), …, Dn], the concatenated result has the shape
[D0, D1, …, Raxis, …, Dn]
where
Raxis = sum(Daxis(i))
That is, the data from the input tensors is joined along the axis dimension, and every other dimension is unchanged.
Note the following constraint:
All input tensors must have the same number of dimensions, and every dimension except axis must be equal.
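The shape rule and the constraint above can be sketched in plain Python, without TensorFlow. The helper `concat_shape` is a hypothetical name used only for illustration (it is not part of the TensorFlow API), and it assumes each shape is given as a list of integers.

```python
def concat_shape(shapes, axis):
    """Return the shape obtained by concatenating inputs of `shapes` along `axis`.

    Hypothetical helper for illustration; it only mirrors the rule described above.
    """
    if not shapes:
        raise ValueError("need at least one shape")
    rank = len(shapes[0])
    if not -rank <= axis < rank:
        raise ValueError("axis {} is out of range for rank {}".format(axis, rank))
    axis = axis % rank  # map a negative axis to its non-negative equivalent
    for shape in shapes:
        if len(shape) != rank:
            raise ValueError("all inputs must have the same number of dimensions")
        for d in range(rank):
            if d != axis and shape[d] != shapes[0][d]:
                raise ValueError("dimension {} differs between inputs".format(d))
    result = list(shapes[0])
    result[axis] = sum(shape[axis] for shape in shapes)  # Raxis = sum(Daxis(i))
    return result


print(concat_shape([[2, 3], [2, 3]], 0))  # [4, 3]
print(concat_shape([[2, 3], [2, 5]], 1))  # [2, 8]
```

Passing shapes that differ in a dimension other than axis, for example `concat_shape([[2, 3], [2, 5]], 0)`, raises a ValueError in this sketch, which corresponds to the constraint stated above.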
"""Example code for tf.concat(https://www.tensorflow.org/versions/r1.8/api_docs/python/tf/concat)
tf.concat(
values,
axis,
name='concat'
)
"""
import sys
import tensorflow as tf
print("=== Version checking ===")
print("Python version: \n{}".format(sys.version))
print("Tensorflow version: {}".format(tf.__version__))
print("========================")
test1 = tf.constant([[1, 2, 3], [4, 5, 6]])
test2 = tf.constant([[7, 8, 9], [10, 11, 12]])
concat1 = tf.concat([test1, test2], 0)
concat2 = tf.concat([test1, test2], 1)
# A negative axis also works; it counts backwards from the last dimension.
# For these 2-D tensors, axis=-2 is axis 0 and axis=-1 is axis 1,
# so concat2_1 equals concat1 and concat2_2 equals concat2.
concat2_1 = tf.concat([test1, test2], -2)
concat2_2 = tf.concat([test1, test2], -1)
print("===== Tensor Shape ======")
print("test1: {}".format(test1))
print("test2: {}".format(test2))
print("concat1 with D-0: {}".format(concat1))
print("concat2 with D-1: {}".format(concat2))
print("concat2_1 with D-0(-2): {}".format(concat2_1))
print("concat2_2 with D-1(-1): {}".format(concat2_2))
with tf.Session() as sess:
    print("test1: \n{}, The shape: {}".format(sess.run(test1), test1.shape))
    print("test2: \n{}, The shape: {}".format(sess.run(test2), test2.shape))
    print("concat1 with D-0: \n{}, The shape: {}".format(sess.run(concat1), concat1.shape))
    print("concat2 with D-1: \n{}, The shape: {}".format(sess.run(concat2), concat2.shape))
    print("concat2_1 with D-0(-2): \n{}, The shape: {}".format(sess.run(concat2_1), concat2_1.shape))
    print("concat2_2 with D-1(-1): \n{}, The shape: {}".format(sess.run(concat2_2), concat2_2.shape))
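To check what the session above should print, the same 2-D concatenation can be reproduced with plain Python lists. This is only a sketch of the behavior, not how TensorFlow implements it; `concat_2d` is a hypothetical helper that handles two-dimensional nested lists only.

```python
def concat_2d(matrices, axis):
    """Concatenate 2-D nested lists along `axis` (one of 0, 1, -1, -2)."""
    if axis not in (-2, -1, 0, 1):
        raise ValueError("axis must be in [-2, 2) for 2-D inputs")
    axis = axis % 2  # -2 -> 0, -1 -> 1
    if axis == 0:
        # Stack the rows of every matrix on top of each other.
        return [list(row) for m in matrices for row in m]
    # axis == 1: join the i-th rows of every matrix side by side.
    return [[x for row in rows for x in row] for rows in zip(*matrices)]


test1 = [[1, 2, 3], [4, 5, 6]]
test2 = [[7, 8, 9], [10, 11, 12]]

print(concat_2d([test1, test2], 0))   # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
print(concat_2d([test1, test2], 1))   # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
print(concat_2d([test1, test2], -2) == concat_2d([test1, test2], 0))  # True
print(concat_2d([test1, test2], -1) == concat_2d([test1, test2], 1))  # True
```

Along axis 0 the result has shape (4, 3); along axis 1 it has shape (2, 6).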
Here is one more example, this time passing plain Python lists instead of tensors.
te1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
te2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
conc1 = tf.concat([te1, te2], -1)
print("===== Python list ======")
print("te1: {}".format(te1))
print("te2: {}".format(te2))
print("\n===== Tensor Shape ======")
print("conc1 with D-2: {}".format(conc1))
with tf.Session() as sess:
    print("conc1: \n{}, The shape: {}".format(sess.run(conc1), conc1.shape))
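For the 3-D lists above, axis=-1 refers to the innermost dimension (axis 2). The following plain-Python sketch shows the value one would expect for conc1; `concat_nested` and `rank_of` are hypothetical helpers written for this illustration, and they assume both inputs are regularly nested lists of the same depth.

```python
def rank_of(nested):
    """Count the nesting depth of a regularly nested list."""
    rank = 0
    while isinstance(nested, list):
        rank += 1
        nested = nested[0]
    return rank


def concat_nested(a, b, axis):
    """Concatenate two equally nested lists along a non-negative `axis`."""
    if axis == 0:
        return list(a) + list(b)
    # Descend one nesting level and concatenate the matching sub-lists.
    return [concat_nested(x, y, axis - 1) for x, y in zip(a, b)]


te1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
te2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]

axis = -1 % rank_of(te1)  # -1 on a rank-3 input is axis 2
expected = concat_nested(te1, te2, axis)
print(expected)  # [[[1, 2, 7, 4], [2, 3, 8, 4]], [[4, 4, 2, 10], [5, 3, 15, 11]]]
```

The two inputs have shape (2, 2, 2), so the result has shape (2, 2, 4): only the last dimension grows.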
Reference
- tf.concat in the TensorFlow API documentation, version r1.8