import glob
import os
import tarfile
from shutil import copyfile
from urllib.request import urlretrieve

import numpy as np
from scipy.io import loadmat
"""
函数说明:按照分类(labels)复制未分组的图片到指定的位置10
Parameters:
data path - 数据存放目录
labels - 数据对应的标签,需要按标签放到不同的目录
"""
def copy_data_files(data path, labels) :
if not os. path, exists( data path) :
os.mkdir(data path)
# 创建分类目录
for i in range(0,102) :
os.mkdir(os.path.join( data path, str(i)))
for label in labels:
src path = str(label[0])
dst path = os.path. join(data path, label[1], src path. split(os. sep)[ - 1])
copyfile(src path, dst path)
def _download_if_missing(url, dst_path, start_msg, done_msg):
    """Download ``url`` to ``dst_path`` unless that file already exists.

    The data is first written to ``dst_path + ".part"`` and then atomically
    renamed. An interrupted download therefore never leaves a truncated file
    at ``dst_path`` that a later run would mistake for a finished one.

    Returns True if a download happened, False if the file was already present.
    """
    if os.path.exists(dst_path):
        return False
    print(start_msg)
    tmp_path = dst_path + ".part"
    urlretrieve(url, tmp_path)
    os.replace(tmp_path, dst_path)
    print(done_msg)
    return True


if __name__ == "__main__":
    # Make sure the local dataset directory exists.
    data_set_path = "./data"
    os.makedirs(data_set_path, exist_ok=True)

    flowers_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"

    # Download the 102 Category Flower image archive and extract it.
    flowers_archive_file = "102flowers.tgz"
    flowers_archive_path = os.path.join(data_set_path, flowers_archive_file)
    _download_if_missing(
        flowers_url_prefix + flowers_archive_file,
        flowers_archive_path,
        "正在下载图片文件...",
        "图片文件下载完成.",
    )
    # Extract whenever the image folder is missing, not only right after a
    # fresh download, so an already-downloaded archive still gets unpacked.
    images_dir = os.path.join(data_set_path, "jpg")
    if not os.path.isdir(images_dir):
        print("正在解压图片文件...")
        with tarfile.open(flowers_archive_path) as archive:
            # The archive comes from the network. Where this Python supports
            # extraction filters, use "data" to reject unsafe members such as
            # absolute paths or ".." traversal.
            if hasattr(tarfile, "data_filter"):
                archive.extractall(path=data_set_path, filter="data")
            else:
                archive.extractall(path=data_set_path)
        print("图片文件解压完成,")

    # Download the label file, which gives the class of each image.
    flowers_labels_file = "imagelabels.mat"
    flowers_labels_path = os.path.join(data_set_path, flowers_labels_file)
    _download_if_missing(
        flowers_url_prefix + flowers_labels_file,
        flowers_labels_path,
        "正在下载标识文件...",
        "标识文件下载完成",
    )
    # Shift the stored labels down by one so they match the class
    # directories "0".."101" created by copy_data_files.
    flower_labels = loadmat(flowers_labels_path)["labels"][0] - 1

    # Download the split file, which lists the training/validation/test ids.
    sets_splits_file = "setid.mat"
    sets_splits_path = os.path.join(data_set_path, sets_splits_file)
    _download_if_missing(
        flowers_url_prefix + sets_splits_file,
        sets_splits_path,
        "正在下载数据集分类文件...",
        "数据集分类文件下载完成",
    )
    sets_splits = loadmat(sets_splits_path)
    # The official split's test set is larger than its training set, so the
    # "tstid" and "trnid" roles are deliberately swapped here.
    train_set = sets_splits["tstid"][0] - 1
    valid_set = sets_splits["valid"][0] - 1
    test_set = sets_splits["trnid"][0] - 1

    # Pair each image file (sorted by name) with its class label.
    image_files = sorted(glob.glob(os.path.join(images_dir, "*.jpg")))
    # Files and labels are paired by position. A mismatch (for example a
    # partial extraction) would otherwise be truncated silently by zip().
    if len(image_files) != len(flower_labels):
        raise RuntimeError(
            f"found {len(image_files)} images in {images_dir!r} but "
            f"{len(flower_labels)} labels in {flowers_labels_path!r}"
        )
    image_labels = np.array(list(zip(image_files, flower_labels)))

    # Copy the training, validation and test sets into separate directories.
    print("正在进行训练集的复制...")
    copy_data_files(os.path.join(data_set_path, "train"), image_labels[train_set, :])
    print("已完成训练集的复制,开始复制验证集...")
    copy_data_files(os.path.join(data_set_path, "valid"), image_labels[valid_set, :])
    print("已完成验证集的复制,开始复制测试集...")
    copy_data_files(os.path.join(data_set_path, "test"), image_labels[test_set, :])
    print("已完成测试集的复制,所有的图片下载和预处理工作已完成.")