如何为您的机器学习问题选择正确的预训练模型

作者: 不靠谱的猫 2019-05-07 11:18:51

 在这篇文章中,我们将简要介绍一下迁移学习是什么,以及如何使用它。

什么是迁移学习?

迁移学习是使用预训练模型解决深度学习问题的艺术。

迁移学习是一种机器学习技术,你可以使用一个预训练好的神经网络来解决一个问题,这个问题类似于网络最初训练用来解决的问题。例如,您可以利用构建好的用于识别狗的品种的深度学习模型来对狗和猫进行分类,而不是构建您自己的模型。这可以为您省去寻找有效的神经网络体系结构的痛苦,可以为你节省花在训练上的时间,并可以保证有良好的结果。也就是说,你可以花很长时间来制作一个50层的CNN来***地区分你的猫和狗,或者你可以简单地使用许多预训练好的图像分类模型。

使用预训练模型的三种不同方式

主要有三种不同的方式可以重新定位预训练模型。他们是,

  1. 特征提取 。
  2. 复制预训练的网络的体系结构。
  3. 冻结一些层并训练其他层。

特征提取:这里我们所需要做的就是改变输出层,以给出cat和dog的概率(或者您的模型试图将内容分类到的类的数量),而不是最初训练它将内容分类到的数千个类。当我们试图训练模型所使用的数据与预训练的模型最初所训练的数据非常相似且数据集的大小很小时,这是理想的。这种机制称为固定特征提取。我们只对添加的新输出层进行重新训练,并保留每一层的权重。

复制预训练网络的架构 :在这里,我们定义了一个与预训练模型具有相同体系结构的机器学习模型,该模型在执行与我们试图实现的任务类似的任务时显示了出色的结果,并从头开始训练它。我们从预训练的模型中丢弃每一层的权重,然后根据我们的数据重新训练整个模型。当我们有大量的数据要训练时,我们会采用这种方法,但它与训练前的模型所训练的数据并不十分相似。

冻结一些层并训练其他层:我们可以选择冻结一个预训练模型的初始k层,只训练最顶层的n-k层。我们保持初始值的权重与预训练模型的权重相同且不变,并对数据的高层进行再训练。当数据集较小且数据相似度较低时,采用该方法。较低的层主要关注可以从数据中提取的最基本的信息,因此可以将其用于其他问题,因为基本级别的信息通常是相同的。

另一种常见情况是数据相似性高且数据集也很大。在这种情况下,我们保留模型的体系结构和模型的初始权重。然后,我们对整个模型进行再训练,以更新预训练模型的权重,以更好地适应我们的特定问题。这是使用迁移学习的理想情况。

下图显示了随着数据集大小和数据相似性的变化而采用的方法。

迁移学习:如何为您的机器学习问题选择正确的预训练模型

PyTorch中的迁移学习

在torchvision.models模块下,PyTorch中有八种不同的预训练模型。他们是 :

  1. AlexNet
  2. VGG
  3. RESNET
  4. SqueezeNet
  5. DenseNet
  6. Inception v3
  7. GoogLeNet
  8. ShuffleNet v2

这些都是为图像分类而构建的卷积神经网络,在ImageNet数据集上进行训练。ImageNet是根据WordNet层次结构组织的图像数据库,包含14,197,122张属于21841类的图像。

迁移学习:如何为您的机器学习问题选择正确的预训练模型

由于PyTorch中的所有预训练模型都针对相同的任务在相同的数据集上进行训练,所以我们选择哪一个并不重要。让我们选择ResNet网络,看看如何在前面讨论的不同场景中使用它。

用于图像识别的ResNet或深度残差学习在pytorch、ResNet -18、ResNet -34、ResNet -50、ResNet -101和ResNet -152上有五个版本。

让我们从torchvision下载ResNet-18。

  1. import torchvision.models as models 
  2. model = models.resnet18(pretrained=True) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

以下是我们刚刚下载的模型。

迁移学习:如何为您的机器学习问题选择正确的预训练模型
迁移学习:如何为您的机器学习问题选择正确的预训练模型

现在,让我们看看尝试,看看如何针对四个不同的问题训​​练这个模型。

数据集很小,数据相似性很高

考虑这个kaggle数据集(https://www.kaggle.com/mriganksingh/cat-images-dataset)。这包括猫的图像和其他非猫的图像。它有209个像素64*64*3的训练图像和50个测试图像。这显然是一个非常小的数据集,但我们知道ResNet是在大量动物和猫图像上训练的,所以我们可以使用ResNet作为固定特征提取器来解决我们的猫与非猫的问题。

  1. num_ftrs = model.fc.in_features 
  2. num_ftrs 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

Out: 512

  1. model.fc.out_features 

Out: 1000

我们需要冻结除***一层之外的所有网络。我们需要设置requires_grad = False来冻结参数,这样就不会在backward()中计算梯度。新构造模块的参数默认为requires_grad=True。

  1. for param in model.parameters(): 
  2.  param.requires_grad = False 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

由于我们只需要***一层提供两个概率,即图像的概率是否为cat,我们可以重新定义***一层中的输出特征数。

  1. model.fc = nn.Linear(num_ftrs, 2

这是我们模型的新架构。

迁移学习:如何为您的机器学习问题选择正确的预训练模型
迁移学习:如何为您的机器学习问题选择正确的预训练模型

我们现在要做的就是训练模型的***一层,我们将能够使用我们重新定位的vgg16来预测图像是否是猫,而且数据和训练时间都非常少。

数据的大小很小,数据相似性也很低

考虑来自(https://www.kaggle.com/kvinicki/canine-coccidiosis),这个数据集包含了犬异孢球虫和犬异孢球虫卵囊的图像和标签,异孢球虫卵囊是一种球虫寄生虫,可感染狗的肠道。它是由萨格勒布兽医学院创建的。它包含了两种寄生虫的341张图片。

迁移学习:如何为您的机器学习问题选择正确的预训练模型

这个数据集很小,而且不是Imagenet中的一个类别。在这种情况下,我们保留预先训练好的模型架构,冻结较低的层并保留它们的权重,并训练较低的层更新它们的权重以适应我们的问题。

  1. count = 0 
  2. for child in model.children(): 
  3.  count+=1 
  4. print(count) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

Out: 10

ResNet18共有10层。让我们冻结前6层。

  1. count = 0 
  2. for child in model.children(): 
  3.  count+=1 
  4.  if count < 7
  5.  for param in child.parameters(): 
  6.  param.requires_grad = False 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

现在我们已经冻结了前6层,让我们重新定义最终输出层,只给出2个输出,而不是1000。

  1. model.fc = nn.Linear(num_ftrs, 2

这是更新的架构。

迁移学习:如何为您的机器学习问题选择正确的预训练模型
迁移学习:如何为您的机器学习问题选择正确的预训练模型

现在,训练这个机器学习模型,更新***4层的权重。

数据集的大小很大,但数据相似性非常低

考虑这个来自kaggle,皮肤癌MNIST的数据集:HAM10000

其具有超过10015个皮肤镜图像,属于7种不同类别。这不是我们在Imagenet中可以找到的那种数据。

这就是我们只保留模型架构而不保留来自预训练模型的任何权重的地方。让我们重新定义输出层,将项目分类为7个类别。

  1. model.fc = nn.Linear(num_ftrs, 7

这个模型需要几个小时才能在没有GPU的机器上进行训练,但是如果你运行足够的时代,你仍然会得到很好的结果,而不必定义你自己的模型架构。

数据大小很大,数据相似性很高

考虑来自kaggle 的鲜花数据集(https://www.kaggle.com/alxmamaev/flowers-recognition)。它包含4242个花卉图像。图片分为五类:洋甘菊,郁金香,玫瑰,向日葵,蒲公英。每个类大约有800张照片。

这是应用迁移学习的理想情况。我们保留了预训练模型的体系结构和每一层的权重,并训练模型更新权重以匹配我们的特定问题。

  1. model.fc = nn.Linear(num_ftrs, 5
  2. best_model_wts = copy.deepcopy(model.state_dict()) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

我们从预训练的模型中复制权重并初始化我们的模型。我们使用训练和测试阶段来更新这些权重。

  1. for epoch in range(num_epochs): 
  2.   
  3.  print(‘Epoch {}/{}’.format(epoch, num_epochs — 1)) 
  4.  print(‘-’ * 10
  5.  for phase in [‘train’, ‘test’]: 
  6.   
  7.  if phase == 'train'
  8.  scheduler.step() 
  9.  model.train()  
  10.  else
  11.  model.eval() 
  12.  running_loss = 0.0 
  13.  running_corrects = 0 
  14.  for inputs, labels in dataloaders[phase]: 
  15.   
  16.  inputs = inputs.to(device) 
  17.  labels = labels.to(device) 
  18.  optimizer.zero_grad() 
  19.  with torch.set_grad_enabled(phase == ‘train’): 
  20.   
  21.  outputs = model(inputs) 
  22.  _, preds = torch.max(outputs, 1
  23.  loss = criterion(outputs, labels) 
  24.   
  25.  if phase == ‘train’: 
  26.  loss.backward() 
  27.  optimizer.step() 
  28.  running_loss += loss.item() * inputs.size(0
  29.  running_corrects += torch.sum(preds == labels.data) 
  30.   
  31.  epoch_loss = running_loss / dataset_sizes[phase] 
  32.  epoch_acc = running_corrects.double() / dataset_sizes[phase] 
  33.  print(‘{} Loss: {:.4f} Acc: {:.4f}’.format( 
  34.  phase, epoch_loss, epoch_acc)) 
  35.   
  36.  if phase == ‘test’ and epoch_acc > best_acc: 
  37.  best_acc = epoch_acc 
  38.  best_model_wts = copy.deepcopy(model.state_dict()) 
  39. print(‘Best val Acc: {:4f}’.format(best_acc)) 
  40. model.load_state_dict(best_model_wts) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

这种机器学习模式也需要几个小时的训练,但即使只有一个训练epoch ,也会产生出色的效果。

您可以按照相同的原则在任何其他平台上使用任何其他预训练的网络执行迁移学习。本文随机挑选了Resnet和pytorch。任何其他CNN都会给出类似的结果。希望这可以节省您使用计算机视觉解决现实世界问题的痛苦时间。

机器学习 人工智能 计算机
上一篇:如何正确实施人工智能? 下一篇:面部识别技术受争议,人工智能该如何控制?
评论
取消
暂无评论,快去成为第一个评论的人吧

更多资讯推荐

「新基建」下大火的工业智能,问题依旧很多

「新基建」火了。连同 5G、人工智能、物联网等信息数字化基础设施,都成为国家新的发展方向,不仅在这些新领域内的从业者们明确了目标,传统行业对数字化转型的需求也蓄势待发。

赵子潇 ·  2天前
特征工程是啥东东?为何需要实现自动化?

如今人工智能(AI)变得越来越普遍和必要。从防止欺诈、实时异常检测到预测客户流失,企业客户每天都在寻找机器学习(ML)的新应用。ML的底层是什么?这项技术如何进行预测?使AI发挥神奇功效的秘诀又是什么?

布加迪 ·  2天前
AI如何改变人类社会的各种业务模式?

在过去的20年中,一些愤世嫉俗的人一直担心,人工智能(AI)的发展会破坏企业结构,导致大量失业和财富不平等加剧。下一个十年将是AI的十年。我们期望看到什么变化?答案是基本流程的转变和减少。

CDA数据分析师 ·  3天前
新冠疫情动态:十大创新,助力对抗COVID-19

从感染快速检测到3D打印解决方案,全球各地的科技企业正携手奋进,希望找到足以战胜新冠病毒大流行的突破性方法。目前有哪些创新成果值得关注?本文将带大家一探究竟。

佚名 ·  3天前
全球首个翻译引擎进化归来 “细节狂魔”搞定方言

最近,一款在线机器翻译软件在日本大火。这款翻译软件名叫DeepL,大火的原因正是因为它工作太负责了,翻译得太过准确,在日本引起了热议。

刘俊寰 ·  3天前
应用程序管理中的AI/ML用例

基于人工智能的操作 (AIOps) 是人工智能和传统 AM/IM 操作的融合。与所有其他领域一样,AI 将对运营管理产生重大影响。

佚名 ·  3天前
学不动了?麻省理工 CS 和 EE 网课开放了

疫情之下,麻省理工学院校长在 3 月上旬曾发通知,其中提到把本剩余课程全部转移到网上。

佚名 ·  3天前
科学家研发出“读心术”,直接将脑电波翻译成文本,错误率低至3%

美国加州大学旧金山分校的科学家,已经训练出一种算法,可以直接将受试者的脑电波实时翻译成句子,错误率仅为 3% 。

张路 ·  3天前
Copyright©2005-2020 51CTO.COM 版权所有 未经许可 请勿转载