Python:深度学习杂草识别 Untitled3.ipynb

发布时间 2023-06-12 13:55:13作者: 王哲MGG_AI
import os
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import cv2 as cv
import numpy as np
# from sklearn.model_selection import train_test_split
from torchvision import transforms as T
from PIL import Image
import torch
import time
from tqdm import tqdm
from torchvision import models
from torch import nn
# from sklearn.utils import shuffle
# import random
import matplotlib.pyplot as plt
from vgg16_cml import VGG16_cml
from torch.optim import lr_scheduler
##################################
这段代码导入了一些常用的库,包括os、pandas、torch、cv2、numpy、torchvision、PIL、time、tqdm、matplotlib和vgg16_cml。这些库提供了一些基础功能,如文件操作(os)、数据处理(pandas)、深度学习框架(torch)、图像处理(cv2、PIL)、数值计算(numpy)、数据可视化(matplotlib)等。此外,还导入了一些特定的类和函数,如Dataset和DataLoader(用于构建数据集和加载数据)、transforms(用于图像预处理)、models(用于加载预训练模型)、nn(用于构建神经网络)等。最后,还导入了一个名为VGG16_cml的模块,可能是一个自定义的模型。
##################################
def data_set(file_ori):
image_path=[]
label=[]
for i,file in enumerate(os.listdir(file_ori)):
for j in os.listdir(os.path.join(file_ori,file)):
image_path.append(os.path.join(file_ori,file,j))
label.append(i)
tmp=np.concatenate([[image_path],[label]],0).transpose()
np.random.seed(0)
np.random.shuffle(tmp)
# image_list=tmp[:,0]
# label_list=tmp[:,1]
return tmp
##################################

这段代码定义了一个名为data_set的函数,它接受一个参数file_ori,表示数据集所在的目录。函数首先定义了两个空列表image_pathlabel,用于存储图像路径和标签。然后,使用两层循环遍历数据集目录下的所有子目录和文件。对于每个文件,将其路径添加到image_path列表中,并将其所在目录的索引添加到label列表中。最后,使用numpy库中的函数将两个列表合并为一个二维数组,并随机打乱顺序。函数返回打乱顺序后的二维数组。