kerasで作った学習済みモデルによる推論をwebサービスにしようとしたときの備忘録です。
結果的に、webサービス化に失敗したので、検討している方の参考になればと思います。
やりたかったこと
webサービスでやりたかったことは、
- 推論サーバに学習済みモデルを追加/削除する
- 推論サーバに画像を渡し、登録済みモデルを用いて推論をする
でした。
できなかったこと
学習済みモデルを複数登録した場合、1つだけ削除したいケースがあると思います。
学習済みモデルをGPUメモリ上に複数展開した場合、そのうち1つだけ削除する(つまり、GPUメモリ上から削除する)ことができませんでした。
そしてもう一つ、
マルチスレッドでの学習済みモデルのGPUメモリ展開ができませんでした。
途中まで実装したコード
途中まで実装した結果を載せます。
色々な制約があるため、加工したソースを載せます。そのままでは動きませんので、ご注意を。
WebAPI
flaskを使ったシンプルな実装です。
APIはset_modelとpredictのみで、学習済みモデルにIDを付与して識別します。
検討段階の実装のため、学習済みモデル及び推論対象の画像はファイルパスを渡し、サーバ側で読み込みをしていました。
学習済みモデルの管理はKerasModelManagerというクラスを作成し、そこで行っています。
本来はset_modelにDELETEメソッドも追加したかったのですが、複数モデルインポートしている状態で1つのモデルを削除することができなかったので、実装していません。
app = Flask(__name__)
@app.route('/api/model/', methods=['PUT'])
def set_model():
mng = KerasModelManager()
model_id = request.form.get('model_id')
model_path = request.form.get('model_path')
mng.set_model(model_id, model_path)
return
@app.route('/api/image/', methods=['POST'])
def predict():
mng = KerasModelManager()
model_id = request.form.get('model_id')
image = request.form.getlist('image')
result = mng.predict(model_id, image)
return result
学習済みモデル管理クラス
学習済みモデルを管理するクラスであるKerasModelManagerは、学習済みモデルのIDと学習済みモデルを保持するKerasModelクラスのインスタンスを保持するdictをclass変数に持ちます。
class変数にすることで、KerasModelManagerのインスタンスが複数生成されても、モデルを一意に管理できます。
学習済みモデルをインポートするset_model関数は、セマフォを使った排他処理になっています。
これはマルチスレッドでの学習済みモデルのインポートが不可能なためです。
また、モデルをインポートした後、ダミー画像の推論をしています。
理由は不明ですが、この処理がないとマルチスレッドでの推論が正常に動作しませんでした。
class KerasModelManager: model_list = {} def __init__(self): self.model_semaphore = threading.Semaphore(1) self.dummy_image_path = 'dummy.jpg' def set_model(self, model_path): with self.model_semaphore: model = KerasModel(model_path) model.predict(self.dummy_image_path) self.model_list[model_path] = model def predict(self, model_id, image): model = self.model_list.get(model_id) return vae.predict(image)
学習済みモデルクラス
学習済みモデルクラスのKerasModelですが、グローバル変数にTensorFlowのグラフを宣言し、学習済みモデルのインポート及び学習済みモデルを使った推論時にwith構文で使用します。
こうしないと複数アクセス時に排他制御がうまくいかず、エラーが発生します。
import tensorflow as tf graph = tf.get_default_graph() class KerasModel: def __init__(self, model_path): global graph with graph.as_default(): self.model = load_model(model_path) def predict(self, img): global graph with graph.as_default(): img = load_img(img) img = img_to_array(img) return self.model.predict(img)
まとめ
kerasの推論をWebサービスで実現しようとした際の備忘録でした。
kerasはモデルの構築および学習に使うのが効果的であって、推論、特にパフォーマンスを求められる推論には向いていないことがあらためてわかりました。
kerasでの推論サーバの実装がうまくいかなかったので、kerasをOpenVINOに変換して実装した例を別の記事で紹介しようと思います。
【2019.09.22追記】
kerasをOpenVINOに変換するためにはまず.pbファイルに変換する必要がありました。
.pbファイルへ変換した際のコードなどを記事にまとめたので紹介します。

以上!
コメント