如何用Tensorflow object-detection API训练模型,找到圣诞老爷爷?

如何用Tensorflow object-detection API训练模型,找到圣诞老爷爷?

​ 原文作者:Varun Vohra

原文链接:https://towardsdatascience.com/is-santa-claus-real-9b7b9839776c

翻译:Xin Qin

Christmas is coming!你是否在期待圣诞老人和他的礼物呢?你想知道哪里可以找到圣诞老人吗?本文将教会你如何通过Tensorflow object-detection API训练自己的目标检测模型(object detector),来找到圣诞老人。

本文的代码可见于github:

https://github.com/turnerlabs/character-finder

代码产生的模型可被延伸用于抓取其他的动画或者真实人物。

这是一张我们找到的圣诞老人跳舞的动图🎅🏻

数据收集

无论何种机器学习模型,最关键的部分都在于数据。我们要搜集不同种类的圣诞老人形象,包括动画、泥塑动画、扮成圣诞老人的真人,因此用来训练的数据具有多样性。为了收集数据,我们写了一个流处理器(stream-processor),它可以用VLC播放器来对网上的视频进行流处理并从中抓取帧。流处理器从视频中抓取帧时不需要等待视频加载。例如,如果视频现在播放到第2秒,流处理器可以抓取第4或5秒的帧。此外,你还可以用ASCII来看视频,是不是特别酷🤓

如何应用流处理器可参考:

https://github.com/turnerlabs/stream-processor

 

下面是我们收集到的一小部分不同的圣诞老人形象。这些或动画或真人装扮的圣诞老人均收集自YouTube。

标记数据

接着我们需要标记数据,例如,给圣诞老人的脸部加一个边框(boundingbox)。给图像数据做标记,我们通常使用的工具是labelimg,但是这里我们采用了(https://medium.com/alex-attia-blog/the-simpsons-characters-recognition-and-detection-part-2-c44f9d5abf37)这篇文章使用的脚本。

要标记图像,我们先点击人物脸部的左上方,然后点击脸部右下方。如果图像中没有出现人物,双击该图像即可删除。脚本代码可参考:

https://github.com/turnerlabs/character-finder/blob/master/detect_labels.py

创建Tensorflow Record文件

我们将边框选中的信息存储为cvs文件,下一步就是将这些cvs文件和图像转换成TFrecord文件,这是Tensorflow目标检测应用程序接口(object-detectionAPI)所使用的文件格式。运行转换的脚本可参考:

https://github.com/turnerlabs/character-finder/blob/master/object_detection/create_characters_tf_record.py

 

我们还需要一个protobuf(Google Protocol Buffer)文本文件来将标记名转换成数字ID。这里我们只需要设置一个类。

创建Config文件

我们以faster_rcnn_inception_resnet config文件为基础训练数据。我们将config文件中类参数的数量改为1,因为我们只有一个类——“圣诞老人(santa)”。我们还将输入路径参数改为之前建立的TFrecord中的参数。我们为faster_rcnn_inception_resnet使用预训练(pre-trained)的checkpoint文件。使用该模型的原因是模型的精确性比模型训练的速度更重要。其他具有不同训练速度和精确度的模型可参考:

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

训练

我们先在自己的计算机上运行训练代码,来检测是否存在问题。如果没有问题,我们在Google云平台的ML engine(https://cloud.google.com/ml-engine/)上运行代码。训练模型要经过超过100,000步。

这个模型无论对动画图片还是真实图片都能很好地运行。

产出模型

训练之后产出的模型要被用来测试不同的图像。因此我们选择训练中得到的最新的checkpoint,将其输出为一个静态的参考图片。将checkpoint转换为静态参考图片的脚本可见:

https://github.com/turnerlabs/character-finder/blob/master/object_detection/export_inference_graph.py

我们还为从Google搜索里返回图像和在返回的图像中找到圣诞老人的模型建立了一个网页。这个网页的所有结果经过筛选之后展示出来的方框的可靠度超过60%。这是网页的一个截图:

下一步

运行训练的过程中,我们观察到TotalLoss值很快地下降到1以下。这意味着模型能够有效地抓取到圣诞老人。

我们知道模型不可能尽善尽美。尽管它在找圣诞老人方面已经能做到相当精确,但有时也会作出错误的预测。正如图所示,像这样的图像模型会误判其包含圣诞老人。

因此模型还存在很大的提升空间。

我们下一步需要在config文件中探索更多不同的参数,以及这些参数如何影响模型的训练和预测结果。

现在,你可以试着用你自己的数据来训练目标检测模型了。