Implementing Image Classification with TensorFlow-Slim

nimin   2020-1-4 20:43
<p>Reference: <a href="https://github.com/tensorflow/models/tree/master/slim" rel="external nofollow" target="_blank">https://github.com/tensorflow/models/tree/master/slim</a></p>
<p>Image classification with TensorFlow-Slim</p>
<p><span style="color: #ff0000"><strong>Preparation</strong></span><br>
</p>
<p>Install TensorFlow</p>
<p>Reference: <a href="https://www.tensorflow.org/install/" rel="external nofollow" target="_blank">https://www.tensorflow.org/install/</a></p>
<p>For example, to install TensorFlow with GPU support for Python 2.7 on Ubuntu:</p>
<div class="blockcode">
<pre class="brush:py;">
wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl</pre>
</div>
<p>Download the TF-Slim image model library</p>
<div class="blockcode">
<pre class="brush:py;">
cd $WORKSPACE
git clone https://github.com/tensorflow/models/

</pre>
</div>
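<p>Prepare the data</p>
<p>Many public datasets are available; here we use the Flowers dataset provided on the official site as an example.</p>
<p>The official site provides code for downloading and converting the data. To understand that code and be able to use our own data, we adapt the official code here.</p>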
<div class="blockcode">
<pre class="brush:bash;">
mkdir -p $WORKSPACE/data
cd $WORKSPACE/data
wget http://download.tensorflow.org/example_images/flower_photos.tgz
tar zxf flower_photos.tgz

</pre>
</div>
<p>The dataset directory structure is as follows:</p>
<div class="blockcode">
<pre class="brush:plain;">
flower_photos
├── daisy
│  ├── 100080576_f52e8ee070_n.jpg
│  └── ...
├── dandelion
├── LICENSE.txt
├── roses
├── sunflowers
└── tulips

</pre>
</div>
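<p>Given the layout above, the class-to-label mapping does not have to be typed by hand. The following is a minimal sketch, assuming that every sub-directory of the dataset root is one class. The function name build_class_ids is hypothetical and is not part of TF-Slim. Sorting the folder names assigns the ids in alphabetical order.</p>

```python
import os


def build_class_ids(data_dir):
    """Map each sub-directory name of data_dir to a consecutive integer id."""
    class_names = sorted(
        name for name in os.listdir(data_dir)
        # Plain files such as LICENSE.txt are skipped; only folders count as classes.
        if os.path.isdir(os.path.join(data_dir, name)))
    return {name: idx for idx, name in enumerate(class_names)}


if __name__ == '__main__':
    # Assumes flower_photos/ was extracted into the current directory.
    if os.path.isdir('flower_photos'):
        print(build_class_ids('flower_photos'))
```

<p>For the five Flowers folders, alphabetical order puts daisy at 0 and tulips at 4.</p>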
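<p>In practice, our own datasets do not necessarily store images in one folder per class, so we generate list.txt to record the mapping between image paths and labels.</p>
<p>Python code:</p>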
<div class="blockcode">
<pre class="brush:py;">
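import os

# Map each class (sub-directory) name to an integer label.
class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
data_dir = 'flower_photos/'
output_path = 'list.txt'

# Each output line has the form "class_name/image_name label".
with open(output_path, 'w') as fd:
  # Sort classes and files so that list.txt is identical across runs.
  for class_name in sorted(class_names_to_ids):
    images_list = sorted(os.listdir(os.path.join(data_dir, class_name)))
    for image_name in images_list:
      fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))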

</pre>
</div>
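<p>Before going further, it can be worth checking that list.txt looks as expected. The following is a minimal sketch, assuming the "path label" line format described above. The helper name count_per_label is hypothetical.</p>

```python
import collections
import os


def count_per_label(list_path):
    """Count how many lines of a 'path label' list file carry each label."""
    counts = collections.Counter()
    with open(list_path) as fd:
        for line in fd:
            line = line.strip()
            if not line:
                continue  # ignore blank lines
            # Split on the last whitespace run so that paths containing spaces still parse.
            _, label = line.rsplit(None, 1)
            counts[int(label)] += 1
    return dict(counts)


if __name__ == '__main__':
    # Assumes list.txt exists in the current directory.
    if os.path.exists('list.txt'):
        print(count_per_label('list.txt'))
```

<p>A label that is missing from the result, or a count that looks far too small, usually points to a wrong directory name or an incomplete download.</p>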
<p>To make the labels easier to look up later, you can also define labels.txt:</p>
<div class="blockcode">
<pre class="brush:plain;">
daisy
dandelion
roses
sunflowers
tulips

</pre>
</div>
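<p>labels.txt can also be generated from the class-to-id mapping, so that line i always holds the name of label i. The following is a minimal sketch. The function names write_labels and read_labels are hypothetical and do not come from TF-Slim.</p>

```python
def write_labels(class_names_to_ids, labels_path):
    """Write one class name per line, ordered by label id."""
    names = sorted(class_names_to_ids, key=class_names_to_ids.get)
    with open(labels_path, 'w') as fd:
        for name in names:
            fd.write(name + '\n')


def read_labels(labels_path):
    """Return a list in which index i holds the name of label i."""
    with open(labels_path) as fd:
        return [line.strip() for line in fd if line.strip()]
```

<p>With this pairing, read_labels('labels.txt')[predicted_id] turns a predicted label id back into a readable class name (predicted_id is a placeholder for whatever id a model returns).</p>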
<p>Randomly split the data into a training set and a validation set:</p>
<p>Python code:</p>
<div class="blockcode">
<pre class="brush:py;">
import random

_NUM_VALIDATION = 350  # number of images held out for validation
_RANDOM_SEED = 0       # fixed seed so that the split is reproducible
list_path = 'list.txt'
train_list_path = 'list_train.txt'
val_list_path = 'list_val.txt'

with open(list_path) as fd:
  lines = fd.readlines()
random.seed(_RANDOM_SEED)
random.shuffle(lines)

# The first _NUM_VALIDATION shuffled lines form the validation set; the rest form the training set.
with open(train_list_path, 'w') as fd:
  fd.writelines(lines[_NUM_VALIDATION:])
with open(val_list_path, 'w') as fd:
  fd.writelines(lines[:_NUM_VALIDATION])

</pre>
</div>
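<p>The same split can be wrapped in a function, which makes it easy to check that the two subsets are disjoint and have the expected sizes. The following is a minimal sketch. The name split_lines is hypothetical, and it uses a private random.Random instance (an assumption made here so that the global random state is left untouched).</p>

```python
import random


def split_lines(lines, num_validation, seed=0):
    """Shuffle a copy of lines and return (train_lines, val_lines)."""
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)
    # The first num_validation shuffled entries are held out for validation.
    return shuffled[num_validation:], shuffled[:num_validation]


if __name__ == '__main__':
    demo = ['img{}.jpg 0\n'.format(i) for i in range(10)]
    train, val = split_lines(demo, 3)
    print(len(train), len(val))
```

<p>Because the seed is fixed, calling the function twice on the same input returns the same split.</p>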
<p>Generate the TFRecord data:</p>
<p>Python code:</p>
<div class="blockcode">
<pre class="brush:py;">
import sys
# Make the TF-Slim code from the cloned models repository importable.
# The path is relative to $WORKSPACE/data; adjust it if slim/ sits elsewhere in your checkout.
sys.path.insert(0, '../models/slim/')
from datasets import dataset_utils
import math
import os
import tensorflow as tf

def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
  # Each non-empty line of the list file is "relative/image/path label".
  with open(list_path) as fd:
    lines = [line.split() for line in fd if line.strip()]
  if not os.path.exists(output_dir):
    os.makedirs(output_dir)
  num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
  with tf.Graph().as_default():
    # Small graph that decodes a JPEG so that the image height and width can be read.
    decode_jpeg_data = tf.placeholder(dtype=tf.string)
    decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        output_path = os.path.join(output_dir,
          'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
        with tf.python_io.TFRecordWriter(output_path) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
          for i in range(start_ndx, end_ndx):
            sys.stdout.write('\r&gt;&gt; Converting image {}/{} shard {}'.format(
              i + 1, len(lines), shard_id))
            sys.stdout.flush()
            # Store the raw JPEG bytes; decoding is only used to obtain the image size.
            with tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb') as f:
              image_data = f.read()
            image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
            height, width = image.shape[0], image.shape[1]
            example = dataset_utils.image_to_tfexample(
              image_data, b'jpg', height, width, int(lines[i][1]))
            tfrecord_writer.write(example.SerializeToString())
  sys.stdout.write('\n')
  sys.stdout.flush()

convert_dataset('list_train.txt', 'flower_photos', 'train/')
convert_dataset('list_val.txt', 'flower_photos', 'val/')

</pre>
</div>
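<p>The sharding arithmetic in convert_dataset can be checked on its own, without TensorFlow. The following is a minimal sketch. The helper name shard_ranges is hypothetical, and unlike the conversion code it also clamps the start index, which only affects how an empty shard is reported.</p>

```python
import math


def shard_ranges(num_examples, num_shards):
    """Return the (start, end) index pair that each shard covers."""
    # Same rounding as convert_dataset: every shard holds at most ceil(n / shards) examples.
    num_per_shard = int(math.ceil(num_examples / float(num_shards)))
    ranges = []
    for shard_id in range(num_shards):
        start = min(shard_id * num_per_shard, num_examples)
        end = min((shard_id + 1) * num_per_shard, num_examples)
        ranges.append((start, end))
    return ranges


print(shard_ranges(350, 5))
# [(0, 70), (70, 140), (140, 210), (210, 280), (280, 350)]
print(shard_ranges(11, 5))
# [(0, 3), (3, 6), (6, 9), (9, 11), (11, 11)]
```

<p>The second call shows that with few examples the last shard can be empty; the file is still created, it just contains no records.</p>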
<p>The resulting directory structure is as follows:</p>
<div class="blockcode">
<pre class="brush:plain;">
data
├── flower_photos
├── labels.txt
├── list_train.txt
├── list.txt
├── list_val.txt
├── train
│  ├── data_00000-of-00005.tfrecord
│  ├── ...
│  └── data_00004-of-00005.tfrecord
└── val
  ├── data_00000-of-00005.tfrecord
  ├── ...
  └── data_00004-of-00005.tfrecord

</pre>
</div>
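<p>To confirm that the conversion wrote every shard, the expected file names can be rebuilt from the same naming pattern and compared with what is on disk. The following is a minimal sketch. The function names expected_shard_paths and missing_shards are hypothetical, and the default of 5 shards is taken from the conversion code in this post.</p>

```python
import os


def expected_shard_paths(output_dir, num_shards=5):
    """List the TFRecord paths that the conversion step is expected to write."""
    return [
        os.path.join(output_dir,
                     'data_{:05}-of-{:05}.tfrecord'.format(shard_id, num_shards))
        for shard_id in range(num_shards)]


def missing_shards(output_dir, num_shards=5):
    """Return the expected shard paths that do not exist on disk."""
    return [path for path in expected_shard_paths(output_dir, num_shards)
            if not os.path.isfile(path)]


if __name__ == '__main__':
    # Assumes the script is run from the data directory shown above.
    for split in ('train', 'val'):
        print(split, missing_shards(split))
```

<p>An empty list for both train and val means that all ten files are present.</p>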