如何为您的机器学习问题选择正确的预训练模型

作者: 不靠谱的猫 2019-05-07 11:18:51

 在这篇文章中,我们将简要介绍一下迁移学习是什么,以及如何使用它。

什么是迁移学习?

迁移学习是使用预训练模型解决深度学习问题的艺术。

迁移学习是一种机器学习技术,你可以使用一个预训练好的神经网络来解决一个问题,这个问题类似于网络最初训练用来解决的问题。例如,您可以利用构建好的用于识别狗的品种的深度学习模型来对狗和猫进行分类,而不是构建您自己的模型。这可以为您省去寻找有效的神经网络体系结构的痛苦,可以为你节省花在训练上的时间,并可以保证有良好的结果。也就是说,你可以花很长时间来制作一个50层的CNN来***地区分你的猫和狗,或者你可以简单地使用许多预训练好的图像分类模型。

使用预训练模型的三种不同方式

主要有三种不同的方式可以重新定位预训练模型。他们是,

  1. 特征提取 。
  2. 复制预训练的网络的体系结构。
  3. 冻结一些层并训练其他层。

特征提取:这里我们所需要做的就是改变输出层,以给出cat和dog的概率(或者您的模型试图将内容分类到的类的数量),而不是最初训练它将内容分类到的数千个类。当我们试图训练模型所使用的数据与预训练的模型最初所训练的数据非常相似且数据集的大小很小时,这是理想的。这种机制称为固定特征提取。我们只对添加的新输出层进行重新训练,并保留每一层的权重。

复制预训练网络的架构 :在这里,我们定义了一个与预训练模型具有相同体系结构的机器学习模型,该模型在执行与我们试图实现的任务类似的任务时显示了出色的结果,并从头开始训练它。我们从预训练的模型中丢弃每一层的权重,然后根据我们的数据重新训练整个模型。当我们有大量的数据要训练时,我们会采用这种方法,但它与训练前的模型所训练的数据并不十分相似。

冻结一些层并训练其他层:我们可以选择冻结一个预训练模型的初始k层,只训练最顶层的n-k层。我们保持初始值的权重与预训练模型的权重相同且不变,并对数据的高层进行再训练。当数据集较小且数据相似度较低时,采用该方法。较低的层主要关注可以从数据中提取的最基本的信息,因此可以将其用于其他问题,因为基本级别的信息通常是相同的。

另一种常见情况是数据相似性高且数据集也很大。在这种情况下,我们保留模型的体系结构和模型的初始权重。然后,我们对整个模型进行再训练,以更新预训练模型的权重,以更好地适应我们的特定问题。这是使用迁移学习的理想情况。

下图显示了随着数据集大小和数据相似性的变化而采用的方法。

迁移学习:如何为您的机器学习问题选择正确的预训练模型

PyTorch中的迁移学习

在torchvision.models模块下,PyTorch中有八种不同的预训练模型。他们是 :

  1. AlexNet
  2. VGG
  3. RESNET
  4. SqueezeNet
  5. DenseNet
  6. Inception v3
  7. GoogLeNet
  8. ShuffleNet v2

这些都是为图像分类而构建的卷积神经网络,在ImageNet数据集上进行训练。ImageNet是根据WordNet层次结构组织的图像数据库,包含14,197,122张属于21841类的图像。

迁移学习:如何为您的机器学习问题选择正确的预训练模型

由于PyTorch中的所有预训练模型都针对相同的任务在相同的数据集上进行训练,所以我们选择哪一个并不重要。让我们选择ResNet网络,看看如何在前面讨论的不同场景中使用它。

用于图像识别的ResNet或深度残差学习在pytorch、ResNet -18、ResNet -34、ResNet -50、ResNet -101和ResNet -152上有五个版本。

让我们从torchvision下载ResNet-18。

  1. import torchvision.models as models 
  2. model = models.resnet18(pretrained=True) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

以下是我们刚刚下载的模型。

迁移学习:如何为您的机器学习问题选择正确的预训练模型
迁移学习:如何为您的机器学习问题选择正确的预训练模型

现在,让我们看看尝试,看看如何针对四个不同的问题训​​练这个模型。

数据集很小,数据相似性很高

考虑这个kaggle数据集(https://www.kaggle.com/mriganksingh/cat-images-dataset)。这包括猫的图像和其他非猫的图像。它有209个像素64*64*3的训练图像和50个测试图像。这显然是一个非常小的数据集,但我们知道ResNet是在大量动物和猫图像上训练的,所以我们可以使用ResNet作为固定特征提取器来解决我们的猫与非猫的问题。

  1. num_ftrs = model.fc.in_features 
  2. num_ftrs 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

Out: 512

  1. model.fc.out_features 

Out: 1000

我们需要冻结除***一层之外的所有网络。我们需要设置requires_grad = False来冻结参数,这样就不会在backward()中计算梯度。新构造模块的参数默认为requires_grad=True。

  1. for param in model.parameters(): 
  2.  param.requires_grad = False 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

由于我们只需要***一层提供两个概率,即图像的概率是否为cat,我们可以重新定义***一层中的输出特征数。

  1. model.fc = nn.Linear(num_ftrs, 2

这是我们模型的新架构。

迁移学习:如何为您的机器学习问题选择正确的预训练模型
迁移学习:如何为您的机器学习问题选择正确的预训练模型

我们现在要做的就是训练模型的***一层,我们将能够使用我们重新定位的vgg16来预测图像是否是猫,而且数据和训练时间都非常少。

数据的大小很小,数据相似性也很低

考虑来自(https://www.kaggle.com/kvinicki/canine-coccidiosis),这个数据集包含了犬异孢球虫和犬异孢球虫卵囊的图像和标签,异孢球虫卵囊是一种球虫寄生虫,可感染狗的肠道。它是由萨格勒布兽医学院创建的。它包含了两种寄生虫的341张图片。

迁移学习:如何为您的机器学习问题选择正确的预训练模型

这个数据集很小,而且不是Imagenet中的一个类别。在这种情况下,我们保留预先训练好的模型架构,冻结较低的层并保留它们的权重,并训练较低的层更新它们的权重以适应我们的问题。

  1. count = 0 
  2. for child in model.children(): 
  3.  count+=1 
  4. print(count) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

Out: 10

ResNet18共有10层。让我们冻结前6层。

  1. count = 0 
  2. for child in model.children(): 
  3.  count+=1 
  4.  if count < 7
  5.  for param in child.parameters(): 
  6.  param.requires_grad = False 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

现在我们已经冻结了前6层,让我们重新定义最终输出层,只给出2个输出,而不是1000。

  1. model.fc = nn.Linear(num_ftrs, 2

这是更新的架构。

迁移学习:如何为您的机器学习问题选择正确的预训练模型
迁移学习:如何为您的机器学习问题选择正确的预训练模型

现在,训练这个机器学习模型,更新***4层的权重。

数据集的大小很大,但数据相似性非常低

考虑这个来自kaggle,皮肤癌MNIST的数据集:HAM10000

其具有超过10015个皮肤镜图像,属于7种不同类别。这不是我们在Imagenet中可以找到的那种数据。

这就是我们只保留模型架构而不保留来自预训练模型的任何权重的地方。让我们重新定义输出层,将项目分类为7个类别。

  1. model.fc = nn.Linear(num_ftrs, 7

这个模型需要几个小时才能在没有GPU的机器上进行训练,但是如果你运行足够的时代,你仍然会得到很好的结果,而不必定义你自己的模型架构。

数据大小很大,数据相似性很高

考虑来自kaggle 的鲜花数据集(https://www.kaggle.com/alxmamaev/flowers-recognition)。它包含4242个花卉图像。图片分为五类:洋甘菊,郁金香,玫瑰,向日葵,蒲公英。每个类大约有800张照片。

这是应用迁移学习的理想情况。我们保留了预训练模型的体系结构和每一层的权重,并训练模型更新权重以匹配我们的特定问题。

  1. model.fc = nn.Linear(num_ftrs, 5
  2. best_model_wts = copy.deepcopy(model.state_dict()) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

我们从预训练的模型中复制权重并初始化我们的模型。我们使用训练和测试阶段来更新这些权重。

  1. for epoch in range(num_epochs): 
  2.   
  3.  print(‘Epoch {}/{}’.format(epoch, num_epochs — 1)) 
  4.  print(‘-’ * 10
  5.  for phase in [‘train’, ‘test’]: 
  6.   
  7.  if phase == 'train'
  8.  scheduler.step() 
  9.  model.train()  
  10.  else
  11.  model.eval() 
  12.  running_loss = 0.0 
  13.  running_corrects = 0 
  14.  for inputs, labels in dataloaders[phase]: 
  15.   
  16.  inputs = inputs.to(device) 
  17.  labels = labels.to(device) 
  18.  optimizer.zero_grad() 
  19.  with torch.set_grad_enabled(phase == ‘train’): 
  20.   
  21.  outputs = model(inputs) 
  22.  _, preds = torch.max(outputs, 1
  23.  loss = criterion(outputs, labels) 
  24.   
  25.  if phase == ‘train’: 
  26.  loss.backward() 
  27.  optimizer.step() 
  28.  running_loss += loss.item() * inputs.size(0
  29.  running_corrects += torch.sum(preds == labels.data) 
  30.   
  31.  epoch_loss = running_loss / dataset_sizes[phase] 
  32.  epoch_acc = running_corrects.double() / dataset_sizes[phase] 
  33.  print(‘{} Loss: {:.4f} Acc: {:.4f}’.format( 
  34.  phase, epoch_loss, epoch_acc)) 
  35.   
  36.  if phase == ‘test’ and epoch_acc > best_acc: 
  37.  best_acc = epoch_acc 
  38.  best_model_wts = copy.deepcopy(model.state_dict()) 
  39. print(‘Best val Acc: {:4f}’.format(best_acc)) 
  40. model.load_state_dict(best_model_wts) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

这种机器学习模式也需要几个小时的训练,但即使只有一个训练epoch ,也会产生出色的效果。

您可以按照相同的原则在任何其他平台上使用任何其他预训练的网络执行迁移学习。本文随机挑选了Resnet和pytorch。任何其他CNN都会给出类似的结果。希望这可以节省您使用计算机视觉解决现实世界问题的痛苦时间。

机器学习 人工智能 计算机
上一篇:如何正确实施人工智能? 下一篇:面部识别技术受争议,人工智能该如何控制?
评论
取消
暂无评论,快去成为第一个评论的人吧

更多资讯推荐

大数据和人工智能如何协同工作

人工智能和机器学习如何帮助组织从大数据中获得更好的业务见解?需要了解人工智能和大数据分析的下一步发展。大数据技术并不像几年前那样广受关注,但这并不意味着大数据技术没有得到发展。如果说有什么不同的话,那就是大数据的规模正在变得越来越大。

Kevin Casey ·  20h前
麻省理工学院开发出组装机器人:未来可建造太空殖民地

麻省理工学院博士生本杰明·杰内特(Benjamin Jenett)和原子中心的尼尔·格申费尔德教授(Neil Gershenfeld)在《电气电子工程师学会机器人与自动化快报》科学期刊上发表报告称,开发出一种组装机器人原型,它可以用很小的零件制成大型结构。

技术力量 ·  20h前
刷脸取件被小学生“破解”!丰巢紧急下线

近日,#小学生发现刷脸取件bug#的话题引发关注!这是真的吗?都市快报《好奇实验室》进行了验证。

好奇实验室 ·  21h前
深度学习/计算机视觉常见的8个错误总结及避坑指南

人类并不是完美的,我们经常在编写软件的时候犯错误。有时这些错误很容易找到:你的代码根本不工作,你的应用程序会崩溃。但有些 bug 是隐藏的,很难发现,这使它们更加危险。

skura ·  21h前
AI艺术日渐繁荣,未来何去何从?

利用人工智能创作而成的画作近年来越来越受瞩目,有的作品甚至能在知名拍卖行拍得高价。但这类作品仍有不少问题需要解答,比如它的作者是开发出算法的程序员还是计算机呢?AI艺术的市场未来将走向何方呢?

网易智能 ·  22h前
人工智能如何改变医疗保健行业

人工智能医疗公司的首席执行官对于人工智能在医学上的应用,如何购买人工智能解决方案,以及人工智能在医疗领域的未来发展进行了阐述。

James Maguire ·  1天前
2019年深度学习自然语言处理十大发展趋势 精选

自然语言处理在深度学习浪潮下取得了巨大的发展,FloydHub 博客上Cathal Horan介绍了自然语言处理的10大发展趋势,是了解NLP发展的非常好的文章。

HU数据派 ·  1天前
4 分钟!OpenAI 的机器手学会单手解魔方了,完全自学无需编程 精选

OpenAI 的机器手学会单手解魔方了,而且还原一个三阶魔方全程只花了 4 分钟,其灵巧程度让人自叹不如。

佚名 ·  1天前
Copyright©2005-2019 51CTO.COM 版权所有 未经许可 请勿转载