Pytorchのモデルを可視化する

ソフトウェア

 

ディープラーニングの学習モデルを作ったら、可視化したいですよね。

今回は、Pytorchで学習モデルを作った際の可視化方法を紹介します。

使用するライブラリは、「make_dot」です。

 

make_dotを使った可視化

 

以下コードです。

今回はPytorchのResnetを可視化してみます。

import torch
import torchvision.models as models
from torchviz import make_dot
alex = models.AlexNet()

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = alex(x)

dot = make_dot(out)
dot.format = 'png'
dot.render('graph_image')

出力される「graph_image.png」は以下のようになります。

 

 

make_dotへの引数がモデル本体ではなく、モデルで予測した結果という点に注意が必要です。

 

make_dotの実行環境

 

「make_dot」はpipでinstallすることができます。

pip install make_dot

 

また、make_dotを動かすには「graphviz」が必要なので、別途インストールしましょう。

「graphviz」は、まずOSにインストールする必要があります。

windowsへのインストール方法は下記参照

Windows で Graphviz のインストール

 

続いて、pythonのgraphvizラッパーをインストールします。

pip install graphviz

 

※condaの仮想環境にて、上記コマンドでインストールしても上手く動作しないことがあります。

その場合は、condaコマンドでインストールしてみましょう。

conda install python-graphviz

 

まとめ

 

Pytorchで作成した学習モデルをmake_dotで可視化する方法を紹介しました。

学習モデルのレイヤー設計レビューなどで使える手法だと思うので、是非お試しあれ。

 

以上!

コメント