@fangyang970206 2018-12-01T10:34:25.000000Z

End-to-End Classification on a Traffic Light Dataset

pytorch Classification Traffic_Light_Dataset


Catching the tail end of November, here is a write-up of a DL assignment I did yesterday. The assignment itself is simple: image classification of traffic lights. But it is a good exercise in building a deep learning system from scratch, since most well-known datasets come pre-packaged and ready to call. This article builds the classification system step by step with the PyTorch framework.

Package requirements:
pytorch: 0.4.0
torchsummary: pip install torchsummary
cv2: pip install opencv-python
matplotlib
numpy

All code is hosted on GitHub at https://github.com/FangYang970206/TL_Dataset_Classification. Run git clone https://github.com/FangYang970206/TL_Dataset_Classification to get a local copy.

1. Dataset Overview

The dataset has 10 categories: red circle, red left, red right, red up, and red negative, plus green circle, green left, green right, green up, and green negative, as shown below:
[Figure: sample images of the ten traffic light classes]
The dataset can be downloaded via the following links: onedrive / baiduyun / google
After downloading, unzip it into the TL_Dataset_Classification-master folder. This produces a new folder, TL_Dataset, with the following layout:
[Figure: directory listing of TL_Dataset]

2. Code Walkthrough

The code was written in VSCode with flake8 enabled. There are nine files in total, introduced one by one below. When reading the code, I suggest starting from main.py to get the overall picture.

2.1 model.py

For a deep learning system, the model is the natural starting point: what kind of model do we want to fit the dataset? So we write the model first. The code is as follows:

    import torch.nn as nn
    from torchsummary import summary


    class A2NN(nn.Module):
        def __init__(self):
            super(A2NN, self).__init__()
            self.main = nn.Sequential(
                nn.Conv2d(3, 16, 3, 1, 1),
                nn.BatchNorm2d(16),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, 32, 3, 1, 1),
                nn.MaxPool2d(2, 2),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, 3, 1, 1),
                nn.MaxPool2d(2, 2),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 64, 3, 1, 1),
                nn.MaxPool2d(2, 2),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
            self.linear = nn.Linear(4*4*64, 9)

        def forward(self, inp):
            x = self.main(inp)
            x = x.view(x.shape[0], -1)
            x = self.linear(x)
            return x


    if __name__ == "__main__":
        # don't name this variable `nn`: that would shadow the torch.nn import
        model = A2NN()
        summary(model, (3, 32, 32))

The model code is straightforward and needs little explanation; if you are missing the basics, please brush up on them first.
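As a quick sanity check on the 4*4*64 input size of the final linear layer: every Conv2d above uses kernel 3, stride 1, padding 1, which preserves the spatial size, while each of the three MaxPool2d(2, 2) layers halves it, so a 32x32 input ends up as a 4x4 map with 64 channels. A minimal sketch of that arithmetic (plain Python, no torch needed; the function name is my own for illustration):

```python
def flatten_features(size=32, n_pools=3, out_channels=64):
    # the 3x3/stride-1/padding-1 convolutions keep the spatial size unchanged;
    # each stride-2 max pool halves it
    for _ in range(n_pools):
        size //= 2
    return size * size * out_channels

print(flatten_features())  # 32 -> 16 -> 8 -> 4, so 4*4*64 = 1024
```

This is why the flatten in forward produces a 1024-dimensional vector, matching nn.Linear(4*4*64, 9).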

2.2 dataset.py

The second step is to build the dataset class. PyTorch provides the torch.utils.data.Dataset class; by overriding its __len__ and __getitem__ methods we get our own data pipeline. __len__ returns the length of the dataset, and __getitem__ supports integer indexing from 0 to len(self) (exclusive), returning the data and label for a given index. The code is as follows:

    import torch
    import cv2
    import torch.utils.data as data

    # both negative classes share label 8, hence the 9 output units in model.py
    class_light = {
        'Red Circle': 0,
        'Green Circle': 1,
        'Red Left': 2,
        'Green Left': 3,
        'Red Up': 4,
        'Green Up': 5,
        'Red Right': 6,
        'Green Right': 7,
        'Red Negative': 8,
        'Green Negative': 8
    }


    class Traffic_Light(data.Dataset):
        def __init__(self, dataset_names, img_resize_shape):
            super(Traffic_Light, self).__init__()
            self.dataset_names = dataset_names
            self.img_resize_shape = img_resize_shape

        def __getitem__(self, ind):
            img = cv2.imread(self.dataset_names[ind])
            img = cv2.resize(img, self.img_resize_shape)
            # parentheses matter here: without them, only 127.5/127.5 = 1
            # would be subtracted instead of scaling pixels to [-1, 1]
            img = (img.transpose(2, 0, 1) - 127.5) / 127.5
            for key in class_light.keys():
                if key in self.dataset_names[ind]:
                    label = class_light[key]
            # pylint: disable=E1101,E1102
            return torch.from_numpy(img), torch.tensor(label)

        def __len__(self):
            return len(self.dataset_names)


    if __name__ == '__main__':
        from torch.utils.data import DataLoader
        from glob import glob
        import os

        path = 'TL_Dataset/Green Up/'
        names = glob(os.path.join(path, '*.png'))
        dataset = Traffic_Light(names, (32, 32))
        dataload = DataLoader(dataset, batch_size=1)
        for ind, (inp, label) in enumerate(dataload):
            print("{}-inp_size:{}-label_size:{}".format(ind, inp.numpy().shape,
                                                        label.numpy().shape))
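One detail worth flagging in __getitem__: written without parentheses, the expression img.transpose(2, 0, 1)-127.5/127.5 binds as img - (127.5/127.5), i.e. it merely subtracts 1 from every pixel rather than scaling values into [-1, 1]. The difference on a single pixel value:

```python
pixel = 255.0

wrong = pixel - 127.5 / 127.5    # division binds first: 255 - 1 = 254.0
right = (pixel - 127.5) / 127.5  # scales into [-1, 1]: here exactly 1.0

print(wrong, right)
```

Always parenthesize the subtraction when normalizing with (x - mean) / scale.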

2.3 utils.py

In dataset.py above, the class is initialized with dataset_names, so utils.py provides a get_train_val_names function that returns the file names for the training and validation sets. There is also a helper that checks whether a folder exists and creates it if not. The code is as follows:

    import os
    from glob import glob


    def get_train_val_names(dataset_path, remove_names, ratio=0.3):
        train_names = []
        val_names = []
        dataset_paths = os.listdir(dataset_path)
        for n in remove_names:
            dataset_paths.remove(n)
        for path in dataset_paths:
            sub_dataset_path = os.path.join(dataset_path, path)
            sub_dataset_names = glob(os.path.join(sub_dataset_path, '*.png'))
            sub_dataset_len = len(sub_dataset_names)
            # the first `ratio` of each class folder goes to validation
            val_names.extend(sub_dataset_names[:int(ratio*sub_dataset_len)])
            train_names.extend(sub_dataset_names[int(ratio*sub_dataset_len):])
        return {'train': train_names, 'val': val_names}


    def check_folder(path):
        if not os.path.exists(path):
            os.mkdir(path)
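To see what the split does to each class folder, here is a minimal sketch of the same slicing logic on a hypothetical list of file names (the names are made up for illustration):

```python
def split_names(names, ratio=0.3):
    # mirrors get_train_val_names: the first `ratio` of each folder
    # becomes validation data, the rest becomes training data
    cut = int(ratio * len(names))
    return {'val': names[:cut], 'train': names[cut:]}

names = ['img{:02d}.png'.format(i) for i in range(10)]
split = split_names(names)
print(len(split['val']), len(split['train']))  # 3 7
```

Because the split is applied per class folder, every class keeps roughly the same 70/30 train/val proportion.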

2.4 trainer.py

With the model built and the dataset ready, the next step is how to train, which is the job of trainer.py. It defines a Trainer class: you pass in the training parameters and call Trainer.train to run training and return the loss. The code is as follows:

    import torch.nn as nn
    from torch.optim import Adam


    class Trainer:
        def __init__(self, model, dataload, epoch, lr, device):
            self.model = model
            self.dataload = dataload
            self.epoch = epoch
            self.lr = lr
            self.device = device
            self.optimizer = Adam(self.model.parameters(), lr=self.lr)
            self.criterion = nn.CrossEntropyLoss().to(self.device)

        def __epoch(self, epoch):
            self.model.train()
            loss_sum = 0
            for ind, (inp, label) in enumerate(self.dataload):
                inp = inp.float().to(self.device)
                label = label.long().to(self.device)
                self.optimizer.zero_grad()
                out = self.model(inp)
                loss = self.criterion(out, label)
                loss.backward()
                loss_sum += loss.item()
                self.optimizer.step()
                print('epoch{}_step{}_train_loss_: {}'.format(epoch,
                                                              ind,
                                                              loss.item()))
            return loss_sum/(ind+1)

        def train(self):
            train_loss = self.__epoch(self.epoch)
            return train_loss

2.5 validator.py

trainer.py handles the training set; during training we also need a validation set to judge how well the model is learning, hence validator.py. It wraps a Validator class, similar to Trainer, except that we do not train or update parameters and the model stays in eval mode, so the code differs slightly from Trainer. Calling Validator.eval returns the loss. The code is as follows:

    import torch
    import torch.nn as nn


    class Validator:
        def __init__(self, model, dataload, epoch, device, batch_size):
            self.model = model
            self.dataload = dataload
            self.epoch = epoch
            self.device = device
            self.batch_size = batch_size
            self.criterion = nn.CrossEntropyLoss().to(self.device)

        def __epoch(self, epoch):
            self.model.eval()
            loss_sum = 0
            # no gradients are needed during validation
            with torch.no_grad():
                for ind, (inp, label) in enumerate(self.dataload):
                    inp = inp.float().to(self.device)
                    label = label.long().to(self.device)
                    out = self.model(inp)
                    loss = self.criterion(out, label)
                    loss_sum += loss.item()
            return {'val_loss': loss_sum/(ind+1)}

        def eval(self):
            val_loss = self.__epoch(self.epoch)
            return val_loss

2.6 logger.py

To watch the whole learning process, we can look at the learning curves. logger.py collects the training and validation losses and plots them. The code is as follows:

    import matplotlib.pyplot as plt
    import os


    class Logger:
        def __init__(self, save_path):
            self.save_path = save_path

        def update(self, Kwarg):
            self.__plot(Kwarg)

        def __plot(self, Kwarg):
            save_img_path = os.path.join(self.save_path, 'learning_curve.png')
            plt.clf()
            plt.plot(Kwarg['train_losses'], label='Train', color='g')
            plt.plot(Kwarg['val_losses'], label='Val', color='b')
            plt.xlabel('epoch')
            plt.ylabel('loss')
            plt.legend()
            plt.title('learning_curve')
            plt.savefig(save_img_path)

2.7 main.py

main.py ties everything above together. The code is as follows:

    import torch
    import argparse
    from model import A2NN
    from dataset import Traffic_Light
    from utils import get_train_val_names, check_folder
    from trainer import Trainer
    from validator import Validator
    from logger import Logger
    from torch.utils.data import DataLoader


    def main():
        parse = argparse.ArgumentParser()
        parse.add_argument('--dataset_path', type=str, default='TL_Dataset/')
        # nargs='+' rather than type=list: type=list would split a
        # command-line string into single characters
        parse.add_argument('--remove_names', nargs='+',
                           default=['README.txt',
                                    'README.png',
                                    'Testset'])
        parse.add_argument('--img_resize_shape', type=tuple, default=(32, 32))
        parse.add_argument('--batch_size', type=int, default=1024)
        parse.add_argument('--lr', type=float, default=0.001)
        parse.add_argument('--num_workers', type=int, default=4)
        parse.add_argument('--epochs', type=int, default=200)
        parse.add_argument('--val_size', type=float, default=0.3)
        parse.add_argument('--save_model', type=bool, default=True)
        parse.add_argument('--save_path', type=str, default='logs/')
        args = vars(parse.parse_args())
        check_folder(args['save_path'])
        # pylint: disable=E1101
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # pylint: disable=E1101
        model = A2NN().to(device)
        names = get_train_val_names(args['dataset_path'], args['remove_names'])
        train_dataset = Traffic_Light(names['train'], args['img_resize_shape'])
        val_dataset = Traffic_Light(names['val'], args['img_resize_shape'])
        train_dataload = DataLoader(train_dataset,
                                    batch_size=args['batch_size'],
                                    shuffle=True,
                                    num_workers=args['num_workers'])
        val_dataload = DataLoader(val_dataset,
                                  batch_size=args['batch_size'],
                                  shuffle=True,
                                  num_workers=args['num_workers'])
        loss_logger = Logger(args['save_path'])
        logger_dict = {'train_losses': [],
                       'val_losses': []}
        for epoch in range(args['epochs']):
            print('<Main> epoch{}'.format(epoch))
            trainer = Trainer(model, train_dataload, epoch, args['lr'], device)
            train_loss = trainer.train()
            if args['save_model']:
                state = model.state_dict()
                torch.save(state, 'logs/nn_state.t7')
            validator = Validator(model, val_dataload, epoch,
                                  device, args['batch_size'])
            val_loss = validator.eval()
            logger_dict['train_losses'].append(train_loss)
            logger_dict['val_losses'].append(val_loss['val_loss'])
            loss_logger.update(logger_dict)


    if __name__ == '__main__':
        main()
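A side note on the --remove_names argument: argparse's type=list does not parse a list, it applies list() to the raw string and splits it into single characters, so the flag only works as long as the default is never overridden. The usual approach is nargs='+', which collects space-separated values into a list. A minimal standalone sketch (the argument name mirrors the one in main.py):

```python
import argparse

parse = argparse.ArgumentParser()
# nargs='+' gathers one or more space-separated values into a list
parse.add_argument('--remove_names', nargs='+',
                   default=['README.txt', 'README.png', 'Testset'])
args = parse.parse_args(['--remove_names', 'README.txt', 'Testset'])
print(args.remove_names)  # ['README.txt', 'Testset']
```

With nargs='+', the command line `python main.py --remove_names README.txt Testset` behaves as expected.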

2.8 compute_prec.py and submit.py

The seven files above already complete the system; the two files below are extras, one for computing the accuracy and one for generating the submission file. Take a look if you are interested.
compute_prec.py is as follows:

    import torch
    import numpy as np
    import argparse
    from model import A2NN
    from dataset import Traffic_Light
    from torch.utils.data import DataLoader
    from utils import get_train_val_names, check_folder


    def main():
        parse = argparse.ArgumentParser()
        parse.add_argument('--dataset_path', type=str, default='TL_Dataset/')
        # nargs='+' rather than type=list, as in main.py
        parse.add_argument('--remove_names', nargs='+',
                           default=['README.txt',
                                    'README.png',
                                    'Testset'])
        parse.add_argument('--img_resize_shape', type=tuple, default=(32, 32))
        parse.add_argument('--num_workers', type=int, default=4)
        parse.add_argument('--val_size', type=float, default=0.3)
        parse.add_argument('--save_path', type=str, default='logs/')
        args = vars(parse.parse_args())
        check_folder(args['save_path'])
        # pylint: disable=E1101
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # pylint: disable=E1101
        model = A2NN().to(device)
        model.load_state_dict(torch.load(args['save_path']+'nn_state.t7'))
        model.eval()
        names = get_train_val_names(args['dataset_path'], args['remove_names'])
        val_dataset = Traffic_Light(names['val'], args['img_resize_shape'])
        val_dataload = DataLoader(val_dataset,
                                  batch_size=1,
                                  num_workers=args['num_workers'])
        count = 0
        for ind, (inp, label) in enumerate(val_dataload):
            inp = inp.float().to(device)
            label = label.long().to(device)
            output = model(inp)
            output = np.argmax(output.to('cpu').detach().numpy(), axis=1)
            label = label.to('cpu').numpy()
            # batch size is 1, so compare the single prediction explicitly
            count += int(output[0] == label[0])
        print('precision: {}'.format(count/(ind+1)))


    if __name__ == "__main__":
        main()

submit.py is as follows:

    import torch
    import numpy as np
    import argparse
    import os
    import cv2
    from model import A2NN
    from utils import check_folder


    def main():
        parse = argparse.ArgumentParser()
        parse.add_argument('--dataset_path', type=str,
                           default='TL_Dataset/Testset/')
        parse.add_argument('--img_resize_shape', type=tuple, default=(32, 32))
        parse.add_argument('--num_workers', type=int, default=4)
        parse.add_argument('--save_path', type=str, default='logs/')
        args = vars(parse.parse_args())
        check_folder(args['save_path'])
        # pylint: disable=E1101
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # pylint: disable=E1101
        model = A2NN().to(device)
        model.load_state_dict(torch.load(args['save_path']+'nn_state.t7'))
        model.eval()
        txt_path = os.path.join(args['save_path'], 'result.txt')
        with open(txt_path, 'w') as f:
            for i in range(20000):
                name = os.path.join(args['dataset_path'], '{}.png'.format(i))
                img = cv2.imread(name)
                img = cv2.resize(img, args['img_resize_shape'])
                # parentheses matter: scale pixels to [-1, 1], matching dataset.py
                img = (img.transpose(2, 0, 1) - 127.5) / 127.5
                img = torch.unsqueeze(torch.from_numpy(img).float(), dim=0)
                img = img.to(device)
                output = model(img).to('cpu').detach().numpy()
                img_class = np.argmax(output, axis=1)
                # os.path.basename is more robust than name.split('/')[2]
                f.write(os.path.basename(name) + ' ' + str(img_class[0]))
                f.write('\n')


    if __name__ == "__main__":
        main()
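A small robustness note on writing the submission lines: extracting the bare file name with name.split('/')[2] only works for a path of exactly that depth with '/' as the separator, whereas os.path.basename handles any prefix and either separator style. A quick check:

```python
import os

# hypothetical test-set path, built the same way submit.py builds it
name = os.path.join('TL_Dataset', 'Testset', '0.png')
print(os.path.basename(name))  # 0.png
```

This keeps the script working even if --dataset_path points somewhere deeper or shallower than the default.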

3. Running the Code

Download the dataset into the TL_Dataset_Classification folder, unzip it, open a terminal in TL_Dataset_Classification, and run:

    $ python main.py

If you also want to compute the accuracy, run the following after training finishes:

    $ python compute_prec.py

If you run into problems, open an issue on GitHub or email me at 867540289@qq.com.

4. Results

Learning curves:
[Figure: learning_curve.png]
On the test set, the model achieves 97.425% accuracy. (Still improving.)

5. Summary

That wraps up the tail end of November. I hope this helps you with deep learning and with PyTorch. December is almost here; good luck to me on my math exam, and best wishes to everyone!
