回到首页

TensorFlow TFRecord 实践

import tensorflow as tf
import numpy as np
import IPython.display as display
import cv2

# The following functions can be used to convert a value to a type compatible with tf.Example.
def _bytes_feature(value):
	"""Wrap one string/bytes value in a ``tf.train.Feature`` holding a BytesList.

	If ``value`` is an eager tensor it is first converted with ``.numpy()``,
	because BytesList cannot unpack a string from an EagerTensor directly.
	"""
	eager_tensor_type = type(tf.constant(0))
	payload = value.numpy() if isinstance(value, eager_tensor_type) else value
	bytes_list = tf.train.BytesList(value=[payload])
	return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
	"""Wrap one float/double value in a ``tf.train.Feature`` holding a FloatList."""
	float_list = tf.train.FloatList(value=[value])
	return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
	"""Wrap one bool/enum/int/uint value in a ``tf.train.Feature`` holding an Int64List."""
	int64_list = tf.train.Int64List(value=[value])
	return tf.train.Feature(int64_list=int64_list)

# Download the two sample images (get_file returns the local cached path).
cat_in_snow = tf.keras.utils.get_file(
	'320px-Felis_catus-cat_on_snow.jpg',
	'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file(
	'194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',
	'https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')


def _show_image(path):
	"""Read the image at ``path`` with OpenCV and show it in the window 'img'.

	Blocks in ``cv2.waitKey()`` until a key is pressed, then returns the
	image array read by ``cv2.imread``.

	Raises:
		FileNotFoundError: if ``cv2.imread`` cannot read ``path``. imread
			signals failure by returning None instead of raising, which would
			otherwise surface later as an obscure error inside ``imshow``.
	"""
	img = cv2.imread(path)
	if img is None:
		raise FileNotFoundError(f'cv2.imread could not read {path!r}')
	cv2.imshow('img', img)
	cv2.waitKey()
	return img


img1 = _show_image(cat_in_snow)
img2 = _show_image(williamsburg_bridge)

# Map each downloaded file path to its integer class label.
image_labels = {cat_in_snow: 0, williamsburg_bridge: 1}

# This is an example, just using the cat image. Read the encoded JPEG bytes
# and close the file immediately instead of leaking the handle.
with open(cat_in_snow, 'rb') as image_file:
	image_string = image_file.read()

label = image_labels[cat_in_snow]

def image_example(image_string, label):
	"""Build a ``tf.train.Example`` describing one JPEG image.

	The JPEG is decoded only to learn its shape. The ``image_raw`` feature
	stores the original encoded bytes unchanged.

	Args:
		image_string: encoded JPEG bytes.
		label: integer class label.

	Returns:
		A tf.train.Example with the features height, width, depth, label
		and image_raw.
	"""
	shape = tf.image.decode_jpeg(image_string).shape
	height, width, depth = shape[0], shape[1], shape[2]
	feature = dict(
		height=_int64_feature(height),
		width=_int64_feature(width),
		depth=_int64_feature(depth),
		label=_int64_feature(label),
		image_raw=_bytes_feature(image_string),
	)
	return tf.train.Example(features=tf.train.Features(feature=feature))

# Print only the first 15 lines of the text-format Example; the image_raw
# bytes make the full dump very long.
print(*str(image_example(image_string, label)).split('\n')[:15], sep='\n')
print('...')

# Write the raw image files to 'images.tfrecords'. First, process the two
# images into 'tf.Example' messages. Then, write to a '.tfrecords' file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
	for filename, label in image_labels.items():
		# Read the encoded JPEG bytes. The with-block closes each file
		# promptly instead of relying on garbage collection.
		with open(filename, 'rb') as image_file:
			image_string = image_file.read()
		tf_example = image_example(image_string, label)
		writer.write(tf_example.SerializeToString())

# Read the tfrecords file back. Reuse record_file so the reader can never
# drift from the path the writer used.
raw_image_dataset = tf.data.TFRecordDataset(record_file)

# Parsing spec: every feature is a scalar (shape []) of the given dtype.
# The keys must match those written by image_example().
image_feature_description = {
	'height': tf.io.FixedLenFeature([], tf.int64),
	'width': tf.io.FixedLenFeature([], tf.int64),
	'depth': tf.io.FixedLenFeature([], tf.int64),
	'label': tf.io.FixedLenFeature([], tf.int64),
	'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
	"""Parse one serialized tf.Example into a dict of tensors.

	Uses the module-level ``image_feature_description`` as the parsing spec.
	"""
	spec = image_feature_description
	return tf.io.parse_single_example(serialized=example_proto, features=spec)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)

# Show every decoded image. The channel axis is reversed before display
# because decode_jpeg yields RGB while OpenCV's imshow expects BGR.
for image_features in parsed_image_dataset:
	rgb = tf.image.decode_jpeg(image_features['image_raw']).numpy()
	bgr = rgb[:, :, ::-1]
	cv2.imshow('img', bgr)
	cv2.waitKey()
cv2.destroyWindow('img')
TFRecord 把重要的数据打包存放在一段连续的二进制记录中，以加快数据集的吞吐速度，尤其适合需要通过网络流式传输的数据和结构化数据。

参考链接：TensorFlow 官方教程《TFRecord 和 tf.Example》

本文创建于 2022.11.1 18:22，修改于 2022.11.1 18:22