在做语义分割的时候会经过读取图像的步骤,根据 TensorFlow 官方教程 我使用了 tf.data.Dataset
这个 API。
根据官方读取图像的例子,一开始我的代码如下:
def load_image(filename, resized_shape):
'''
:param filename: 图像文件名
:param resized_shape: 缩放后图像大小
'''
image = tf.read_file(filename)
image = tf.image.decode_png(image)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize_images(
image, size=resized_shape, method=tf.image.ResizeMethod.AREA)
return image
因为训练图像是 png
格式,因此使用了 decode_png
。
这里还使用了 resize_images
函数,因为在进行小范围测试时将图像缩小可以加快训练速度。
至此都是没有问题的,但是最近处理数据集的时候遇到了 jpg
格式的训练图像,之前看到 TensorFlow 有 decode_image
这个函数,好像可以自动判定图像格式然后 decode。
但是使用了之后报错了,这个错误是在 resize_images
的时候发生的:
ValueError: 'images' contains no shape.
根据 decode_image
的官方文档:
Returns:
Tensor with type uint8 with shape [height, width, num_channels] for BMP
, JPEG, and PNG images and shape [num_frames, height, width, 3] for GIF images.
返回的 Tensor 是有形状的,但是从调试中可以看到 shape 是 unknown 的,所以返回应该是没有形状,google 了一下也没有发现能说清这个问题的,因此这个函数暂时用不了了。
不过我发现了一种解决办法:
def load_image(filename, resized_shape):
'''
:param filename: 图像文件名
:param resized_shape: 缩放后图像大小
'''
image = tf.read_file(filename)
image = tf.cond(
tf.image.is_jpeg(image),
lambda: tf.image.decode_jpeg(image),
lambda: tf.image.decode_png(image))
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize_images(
image, size=resized_shape, method=tf.image.ResizeMethod.AREA)
return image
就是使用 tf.cond()
函数,通过 tf.image.is_jpeg(image)
判断图像是不是 jpg
格式,如果是,就执行 decode_jpeg
,如果不是,就执行 decode_png
。因为在语义分割中,大部分训练图像都是 jpg
或 png
,很少会有其他格式的图像,因此用一个条件就够了。
其实 decode_image
函数里面就是使用的 tf.cond
来判断的,判断之后 decode,然后再 convert_image_dtype
,至于为什么返回没有 shape 我也不清楚。
还有一种方法是使用 tf.Tensor.set_shape ,我的代码里不方便使用这个方法, 所以就没有尝试。