DENXブログ

同志社大学 電気情報研究会(DENX)の活動や日常風景などを紹介します。

深層学習で線画着色してみた!

こんにちは!DENXでは数少ない機械学習をしている京田辺二回のryです。
前回の「深層学習で妹Botを作ってみた!」に引き続き二回目のブログになります。
今回は深層学習で線画着色に挑戦してみました!
f:id:denx:20190601084144j:plain:w400

本題

線画着色をやる前は主に自然言語処理辺りをやっていて画像処理の分野はほとんどやっていませんでした。
そろそろ新しい分野に手を出そうと思って音声処理か画像処理かとなった時、比較的簡単にデータセットを集めることが出来る画像処理をやろうと思いました。
CycleGANで表情変換したことについても書きたかったんですが、学習が上手くいきませんでした。

全体の流れ

1  データセット集め
2  前処理
3  学習
4  結果
5  まとめ

データセット集め

データセットを集めるにあたって利用したサイトは「Danbooru」です。
髪の色や表情でタグ付けされているので表情変換や髪色変換をしたい場合にも最適です。
スクレイピングをする際、前回と同様にpythonのBeautifulSoupとrequestsを使いました。
requestsを使って取得したhtmlから画像のurlはdata-file-urlの後にあることが分かったので、BeautifulSoupを使ってdata-file-urlの後にある画像のurlを取り出しurlから画像を保存します。
そして、そのページ内のすべての画像を取得し終えると次のページに飛ぶようにします。
約一日集めて50000枚の画像が集まりました。

前処理

取得したイラストを学習できる形にします。
顔の着色をするために顔の部分だけを切り出さないといけません。
lbpcascade_animeface.xmlOpenCVを使ってイラストから顔を切り出してから画像を128x128にリサイズしました。
学習にはカラーの画像だけではなく線画の画像も必要です。
カラーの画像を線画にする際にはこのサイトを参考にしました。
これで線画とカラーのデータセットが出来ました。

学習

今回学習に使ったモデルはU-NETです。
f:id:denx:20190530212801p:plain Chainerを使って学習するモデルを組み立てていきます。
unpooling2x2をDeconvolutionに置き替えました。
コードは以下の通りです。

import chainer
import chainer.links as L
import chainer.functions as F
class U_NET(chainer.Chain):
    def __init__(self):
        super(U_NET,self).__init__()
        with self.init_scope():
            w = chainer.initializers.Normal(0.02)
            #ヒントを教えたい場合は線画(1ch)+カラー画像のヒント(3ch)の4chを入力にしてください。
            self.c1=L.Convolution2D(1,32,ksize=3,stride=1,pad=1,initialW=w)
            self.c2=L.Convolution2D(32,32,ksize=3,stride=1,pad=1,initialW=w)
            self.bn1=L.BatchNormalization(32)
            self.bn2=L.BatchNormalization(32)
            self.c3=L.Convolution2D(32,64,ksize=3,stride=1,pad=1,initialW=w)
            self.c4=L.Convolution2D(64,64,ksize=3,stride=1,pad=1,initialW=w)
            self.bn3=L.BatchNormalization(64)
            self.bn4=L.BatchNormalization(64)
            self.c5=L.Convolution2D(64,128,ksize=3,stride=1,pad=1,initialW=w)
            self.c6=L.Convolution2D(128,128,ksize=3,stride=1,pad=1,initialW=w)
            self.bn5=L.BatchNormalization(128)
            self.bn6=L.BatchNormalization(128)
            self.c7=L.Convolution2D(128,256,ksize=3,stride=1,pad=1,initialW=w)
            self.c8=L.Convolution2D(256,256,ksize=3,stride=1,pad=1,initialW=w)
            self.bn7=L.BatchNormalization(256)
            self.bn8=L.BatchNormalization(256)
            self.c9=L.Convolution2D(256,512,ksize=3,stride=1,pad=1,initialW=w)
            self.c10=L.Convolution2D(512,512,ksize=3,stride=1,pad=1,initialW=w)
            self.bn9=L.BatchNormalization(512)
            self.bn10=L.BatchNormalization(512)
            self.dc1=L.Deconvolution2D(512,512,ksize=2,stride=2,pad=0,initialW=w)
            self.c11=L.Convolution2D(768,256,ksize=3,stride=1,pad=1,initialW=w)
            self.c12=L.Convolution2D(256,256,ksize=3,stride=1,pad=1,initialW=w)
            self.bn11=L.BatchNormalization(256)
            self.bn12=L.BatchNormalization(256)
            self.dc2=L.Deconvolution2D(256,256,ksize=2,stride=2,pad=0,initialW=w)
            self.c13=L.Convolution2D(384,128,ksize=3,stride=1,pad=1,initialW=w)
            self.c14=L.Convolution2D(128,128,ksize=3,stride=1,pad=1,initialW=w)
            self.bn13=L.BatchNormalization(128)
            self.bn14=L.BatchNormalization(128)
            self.dc3=L.Deconvolution2D(128,128,ksize=2,stride=2,pad=0,initialW=w)
            self.c15=L.Convolution2D(192,64,ksize=3,stride=1,pad=1,initialW=w)
            self.c16=L.Convolution2D(64,64,ksize=3,stride=1,pad=1,initialW=w)
            self.bn15=L.BatchNormalization(64)
            self.bn16=L.BatchNormalization(64)
            self.dc4=L.Deconvolution2D(64,64,ksize=2,stride=2,pad=0,initialW=w)
            self.c17=L.Convolution2D(96,32,ksize=3,stride=1,pad=1,initialW=w)
            self.c18=L.Convolution2D(32,32,ksize=3,stride=1,pad=1,initialW=w)
            self.bn17=L.BatchNormalization(32)
            self.bn18=L.BatchNormalization(32)
            self.c19=L.Convolution2D(32,3,ksize=1,stride=1,pad=0,initialW=w)
    def __call__(self,x):
        h1=F.relu(self.bn1(self.c1(x)))
        h2=F.relu(self.bn2(self.c2(h1)))
        m1=F.max_pooling_2d(h2,ksize=2,stride=2)
        h3=F.relu(self.bn3(self.c3(m1)))
        h4=F.relu(self.bn4(self.c4(h3)))
        m2=F.max_pooling_2d(h4,ksize=2,stride=2)
        h5=F.relu(self.bn5(self.c5(m2)))
        h6=F.relu(self.bn6(self.c6(h5)))
        m3=F.max_pooling_2d(h6,ksize=2,stride=2)
        h7=F.relu(self.bn7(self.c7(m3)))
        h8=F.relu(self.bn8(self.c8(h7)))
        m4=F.max_pooling_2d(h8,ksize=2,stride=2)
        h9=F.relu(self.bn9(self.c9(m4)))
        h10=F.relu(self.bn10(self.c10(h9)))
        u1=self.dc1(h10)
        h11=F.relu(self.bn11(self.c11(F.concat([h8,u1]))))
        h12=F.relu(self.bn12(self.c12(h11)))
        u2=self.dc2(h12)
        h13=F.relu(self.bn13(self.c13(F.concat([h6,u2]))))
        h14=F.relu(self.bn14(self.c14(h13)))
        u3=self.dc3(h14)
        h15=F.relu(self.bn15(self.c15(F.concat([h4,u3]))))
        h16=F.relu(self.bn16(self.c16(h15)))
        u4=self.dc4(h16)
        h17=F.relu(self.bn17(self.c17(F.concat([h2,u4]))))
        h18=F.relu(self.bn18(self.c18(h17)))
        h19=self.c19(h18)
        return h19 

これで学習させるモデルを定義出来ました!
データセットの読み込みはImageDatasetで画像を読み込んでTupleDatasetでまとめました。
画像を読み込む際に0から1に正規化しています。

    IMG_DIR_C="./color"
    IMG_DIR_S="./senga"
    img_c=os.listdir(IMG_DIR_C)
    img_s=os.listdir(IMG_DIR_S)
    img_c = datasets.ImageDataset(paths=img_c,root=IMG_DIR_C)
    img_c = datasets.TransformDataset(img_c, lambda x: x / 255.)
    img_s = datasets.ImageDataset(paths=img_s,root=IMG_DIR_S)
    img_s=datasets.TransformDataset(img_s, lambda x: x / 255.)
    dataset=datasets.TupleDataset(img_c,img_s)
    train_iter=chainer.iterators.SerialIterator(dataset,batch_size=args.batch_size)

TupleDatasetで読み込んだ場合はchainer.dataset.concat_examplesで線画とカラー画像を取り出します。
損失関数には教師データと線画を着色した画像の平均二乗誤差を取ったものを損失としました。
最後に学習方法やモデルの保存間隔を決めました。
今回もBOXのGTX 1080 Tiを使って学習させました。

結果

教師データに含まれない画像を着色させてみました。

10000iteration

f:id:denx:20190601203748j:plain:w400 f:id:denx:20190601203807j:plain:w400 f:id:denx:20190601203828j:plain:w400 f:id:denx:20190601203852j:plain:w400 f:id:denx:20190601203902j:plain:w400 f:id:denx:20190601203913j:plain:w400
青みが強いですね…

最終的にはこうなりました! (120000iteration)

f:id:denx:20190601203046j:plain:w400 f:id:denx:20190601203106j:plain:w400 f:id:denx:20190601203122j:plain:w400 f:id:denx:20190601203114j:plain:w400 f:id:denx:20190601084144j:plain:w400 f:id:denx:20190601203132j:plain:w400

まとめ

今回は「深層学習で線画着色をしてみた!」について書かせていただきました。
PaintsChainerやstyle2paintsには到底及びませんがそれなりの結果が出てよかったです。
機械学習に興味がある人はぜひDENXへ!
最後までお読みいただきありがとうございます!