アルパカ三銃士

〜アルパカに酔いしれる獣たちへ捧げる〜

AI::MxNet の char_lstm.pl を触ってみた

前々から LSTM に興味があった(何かしらの文章を生成させてみたいと思ってた)ため、今回 AI::MxNet の example に含まれている char_lstm.pl を触ってみた。 github.com

LSTM については以下の記事が丁寧に解説している。 s0sem0y.hatenablog.com

このサンプルはシェイクスピア風の文章を生成するといったサンプル。学習用に使用されるデータは「The Tragedy of Coriolanus」である。サンプルのコードはパッと見た感じ難しく感じるが、ちゃんと読み進めていくと大したことはないため、ここで AI::MxNet に関連する大事な部分のみ解説してみる。

Cell に関してのドキュメントは metacpan公式のページを照らし合わせながら読むと良く分かる。
まずは 175 行目にあるこれ。

my $stack = mx->rnn->SequentialRNNCell();

これを使うことによって複数の cell をスタック形式で保持することができる。

$stack->reset;

とすることによって直近でグラフ描写?に使われた cell をリセットすることができるらしい。
スタックに貯めた複数の cell 全体を一つの cell として扱うことができ、これらを用いることでモデルの学習パフォーマンスや予測値が向上するとのこと。 SequentialRNNCell へ cell を追加するにはこのサンプル場合だと

my $cell = mx->rnn->$mode(num_hidden => $num_hidden, prefix => "lstm_${i}l0_");
if($bidirectional)
{
    $cell = mx->rnn->BidirectionalCell(         
        $cell,
        mx->rnn->$mode(
            num_hidden => $num_hidden,
            prefix => "lstm_${i}r0_"
        ),
        output_prefix => "bi_lstm_$i"
    );
}
$stack->add($cell); # cell の追加

が相当する。ちなみに mx->rnn->$mode はオプションで値を指定しない限り mx->rnn->LSTMCell と等価である。

ここからネットワークの定義を行っていく。
それがサンプルコード内の以下の部分になる。

my $data  = mx->sym->Variable('data');
my $label = mx->sym->Variable('softmax_label');
my $embed = mx->sym->Embedding(
        data => $data, input_dim => scalar(keys %vocabulary),
        output_dim => $num_embed, name => 'embed'
);
$stack->reset;
my ($outputs, $states) = $stack->unroll($seq_size, inputs => $embed, merge_outputs => 1);
my $pred  = mx->sym->Reshape($outputs, shape => [-1, $num_hidden*(1+($bidirectional ? 1 : 0))]);
$pred     = mx->sym->FullyConnected(data => $pred, num_hidden => $data_iter->vocab_size, name => 'pred');
$label    = mx->sym->Reshape($label, shape => [-1]);
my $net   = mx->sym->SoftmaxOutput(data => $pred, label => $label, name => 'softmax');

これらのコードについて以下の qiita の記事を読むとわかりやすい。 qiita.com

こうやって読んでいくと MxNet を用いると簡単にネットワーク層の設計ができるため、非常に良い。

最後にモデルの定義と学習を行うコード。 $model->fit で学習を行っている。 この定義時に使われている context はデバイスに関する情報を渡している。AI::MXNet::Context の metacpan を読むと簡単に書いているのですぐに分かる。

my $model = mx->mod->Module(
    symbol  => $net,
    context => $contexts
);
$model->fit(
    $data_iter,
    eval_metric         => mx->metric->Perplexity,
    kvstore             => $kv_store,
    optimizer           => $optimizer,
    optimizer_params    => {
                                learning_rate => $lr,
                                momentum      => $mom,
                                wd            => $wd,
                                clip_gradient => 5,
                                rescale_grad  => 1/$batch_size,
                                lr_scheduler  => AI::MXNet::FactorScheduler->new(step => 1000, factor => 0.99)
                        },
    initializer         => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
    num_epoch           => $num_epoch,
    batch_end_callback  => mx->callback->Speedometer($batch_size, $disp_batches),
    ($chkp_epoch ? (epoch_end_callback  => [mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch), \&sample]) : ())
);

これらのそれぞれの意味は AI::MXNet::Module::Base の metacpan公式に書かれているのを読むとそれぞれのパラメータについて分かりやすい上に、どのようなメソッドが使えるのかも理解できる。

今回感じたこととして、AI::MxNet だけのドキュメントだけを読んで解決しようとしてもそれぞれのメソッドの意味が全く分からなくて辛い。そのため、公式の Python 用のドキュメントも一緒に合わせて読むことで理解を含めることができた。
ちなみにこのサンプルを実行して 1 時間半経過して、この記事を書いていたがそれでもまだ学習が終了しない…

ドキュメント一覧