版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u014256231/article/details/79036422
在做Tensorflow官方文档cifar10,运行cifar10_train.py时,网速太慢,三个小时只下载了30%,于是找方法自己下载并解压。
步骤如下:
先下载文件(包括cifar-10-binary.tar.gz 和 用到的cifar10.py cifar10_train.py等文件)
链接:https://pan.baidu.com/s/1eTwtvvc 密码:deky将cifar-10-binary.tar.gz拷贝到/tmp/cifar10_data/。如果之前运行过程序,需要将此路径下的cifar-10-batches-bin文件夹删除。
具体:
在终端进入下载文件的目录,之后
cp cifar-10-binary.tar.gz /tmp/cifar10_data/
rm -r /tmp/cifar10_data/cifar-10-batches-bin
之后,ls /tmp/cifar10_data,检查是否只有cifar-10-binary.tar.gz一个文件- 确认无误后,下载cifar10.py and cifar10_input.py,在同路径下运行代码:
#!/usr/bin/python
#-*-coding:utf-8-*-
import cifar10
import cifar10_input
import tensorflow as tf
import os
import tarfile
#解压缩
filepath = '/tmp/cifar10_data/cifar-10-binary.tar.gz'
dest_directory = '/tmp/cifar10_data'
extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
if not os.path.exists(extracted_dir_path):
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
data_dir = '/tmp/cifar10_data/cifar-10-batches-bin'
batch_size = 100
#生成CIFAR-10的训练数据和训练标签数据
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)
#生成CIFAR-10的测试数据和测试标签数据
images_test, labels_test = cifar10_input.inputs(eval_data=True, data_dir=data_dir, batch_size=batch_size)
sess = tf.InteractiveSession()
tf.global_variables_initializer()
tf.train.start_queue_runners()
print(images_train)
print(images_test)
- 代码运行结束后,可以再进入目录看一下,可以看到将cifar10数据集成功解压啦~