本文面对三种常常遇到的情况,总结三种读取数据的方式,分别用于处理单张图片、大量图片,和TFRecorder读取方式。并且还补充了功能相近的tf函数。
1、处理单张图片
我们训练完模型之后,常常要用图片测试,有的时候,我们并不需要对很多图像做测试,可能就是几张甚至一张。这种情况下没有必要用队列机制。
import tensorflow as tf import matplotlib.pyplot as <b>本文来源gao@dai!ma.com搞$代^码!网7</b>plt def read_image(file_name): img = tf.read_file(filename=file_name) # 默认读取格式为uint8 print("img 的类型是",type(img)); img = tf.image.decode_jpeg(img,channels=0) # channels 为1得到的是灰度图,为0则按照图片格式来读 return img def main( ): with tf.device("/cpu:0"): # img_path是文件所在地址包括文件名称,地址用相对地址或者绝对地址都行 img_path='./1.jpg' img=read_image(img_path) with tf.Session() as sess: image_numpy=sess.run(img) print(image_numpy) print(image_numpy.dtype) print(image_numpy.shape) plt.imshow(image_numpy) plt.show() if __name__=="__main__": main() """
输出结果为:
img 的类型是 <class ‘tensorflow.python.framework.ops.Tensor’>
[[[196 219 209]
[196 219 209]
[196 219 209]
…[[ 71 106 42]
[ 59 89 39]
[ 34 63 19]
…
[ 21 52 46]
[ 15 45 43]
[ 22 50 53]]]
uint8
(675, 1200, 3)
“””
和tf.read_file用法相似的函数还有tf.gfile.FastGFile tf.gfile.GFile,只是要指定读取方式是’r’ 还是’rb’ 。
2、需要读取大量图像用于训练
这种情况就需要使用Tensorflow队列机制。首先是获得每张图片的路径,把他们都放进一个list里面,然后用string_input_producer创建队列,再用tf.WholeFileReader读取。具体请看下例:
def get_image_batch(data_file,batch_size): data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)] #这个num_epochs函数在整个Graph是local Variable,所以在sess.run全局变量的时候也要加上局部变量。 filenames_queue=tf.train.string_input_producer(data_names,num_epochs=50,shuffle=True,capacity=512) reader=tf.WholeFileReader() _,img_bytes=reader.read(filenames_queue) image=tf.image.decode_png(img_bytes,channels=1) #读取的是什么格式,就decode什么格式 #解码成单通道的,并且获得的结果的shape是[?, ?,1],也就是Graph不知道图像的大小,需要set_shape image.set_shape([180,180,1]) #set到原本已知图像的大小。或者直接通过tf.image.resize_images,tf.reshape() image=tf.image.convert_image_dtype(image,tf.float32) #预处理 下面的一句代码可以换成自己想使用的预处理方式 #image=tf.divide(image,255.0) return tf.train.batch([image],batch_size)
这里的date_file是指文件夹所在的路径,不包括文件名。第一句是遍历指定目录下的文件名称,存放到一个list中。当然这个做法有很多种方法,比如glob.glob,或者tf.train.match_filename_once
全部代码如下:
import tensorflow as tf import os def read_image(data_file,batch_size): data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)] filenames_queue=tf.train.string_input_producer(data_names,num_epochs=5,shuffle=True,capacity=30) reader=tf.WholeFileReader() _,img_bytes=reader.read(filenames_queue) image=tf.image.decode_jpeg(img_bytes,channels=1) image=tf.image.resize_images(image,(180,180)) image=tf.image.convert_image_dtype(image,tf.float32) return tf.train.batch([image],batch_size) def main( ): img_path=r'F:\dataSet\WIDER\WIDER_train\images\6--Funeral' #本地的一个数据集目录,有足够的图像 img=read_image(img_path,batch_size=10) image=img[0] #取出每个batch的第一个数据 print(image) init=[tf.global_variables_initializer(),tf.local_variables_initializer()] with tf.Session() as sess: sess.run(init) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess,coord=coord) try: while not coord.should_stop(): print(image.shape) except tf.errors.OutOfRangeError: print('read done') finally: coord.request_stop() coord.join(threads) if __name__=="__main__": main() """