もっちさんの明日はどっちだ

あした、なに観て 生きていく?

【スポンサーリンク】

MNIST で chainer 2.0 の仕様をお試し

普段 chainer を使っているんですが、最近 ver 2.0 に大幅アップデートがあり、自分のコードも2.0に対応するさせるために雛形を作ってみたのでお試し更新。

とりあえず MNIST をやるとして、trainerで抽象化させすぎると使いにくいので、custom loopをベースに自分がハンドリングしやすいようにアレンジしてみる。

https://github.com/chainer/chainer/tree/master/examples/mnist

変更点

以下を参照してみると

Upgrade Guide from v1 to v2 — Chainer 2.0.0 documentation

僕のソースコードだと chainer 2.0 のクリティカルなアップデートは、dropout などの train フラグと Variable の volatile の廃止。

とりあえずネットワークは dropout を噛ませた以下にしてみる。

class MTNNet(chainer.Chain):
    def __init__(self, n_mid, n_out):
        super(MTNNet, self).__init__()
        with self.init_scope():
            self.lin1 = L.Linear(None, n_mid)
            self.lin2 = L.Linear(None, n_out)
        
    def __call__(self, x):
        h1 = self.lin1(x)
        h2 = F.relu(h1)
        h3 = F.dropout(h2)
        
        y = self.lin2(h3)
        return y
    
    def loss(self, x, t):
        y = self(x)
        loss = F.softmax_cross_entropy(y,t)
        self.accuracy = F.accuracy(y,t)
        return y, loss

config でフラグが管理されており、上述の2点について chainer.config.enable_backprop、chainer.config.train がデフォルトで True になっている。つまり Variable の volatile=False、train=True の状態なので、学習時は気にせず実行すればいいらしい。 test 時は with でフラグを False にしてやる(抜けると元に戻る)。

with chainer.using_config('train', False), chainer.no_backprop_mode():
    y, loss = self.model.loss(x, t)

こうすれば、volatile=True, train=False で実行できる。

古いソースコードをコピペで使っていたので、custom loop でも iterators をうまく使えばすっきりするのかとサンプルを見ながら。とりあえず止まらなければいいやでserializers とかを try で投げてたりするんですが、お気になさらず。

github.com

上がソースコード。 これを雛形にして今後遊べそうだなと思ったところで、今回はこんな感じ。

深層学習 (機械学習プロフェッショナルシリーズ)

深層学習 (機械学習プロフェッショナルシリーズ)

【スポンサーリンク】