"BOKU"のITな日常

BOKUが勉強したり、考えたことを頭の整理を兼ねてまとめてます。

Python3プログラムのユニットテストのついでにボトルネックもチェックする/unittest・profile

f:id:arakan_no_boku:20210912183619p:plain

目次

Pythonのunittestおさらい

今回のメインはProfileです。

ただ、確認にunittestを使うので、ざっくりおさらいしておきます。

Pythonのunittestnoのソースは基本的に以下のような構造です。

import unittest


class TestWithProfile(unittest.TestCase):

    def setUp(self):
        # データ設定
        pass

    def test_func(self):
        # テスト実行
        self.assertTrue(1)


if __name__ == '__main__':
    unittest.main()

setUp()で、テストデータを準備して、test_xxxxのメソッドがテスト実行されます。 

実行は。個別にテストケースを期したファイルを実行してもいいですが、

python -m unittest

とか

python -m unittest dicover

で一括実行できます。

テストの失敗の検査と報告を行うメソッド(アサーションメソッド)は、主に以下のようなものがあります。

メソッド 確認事項
assertEqual(a, b) a == b
assertNotEqual(a, b) a != b
assertTrue(x) bool(x) is True
assertFalse(x) bool(x) is False
assertIs(a, b) a is b
assertIsNot(a, b) a is not b
assertIsNone(x) x is None
assertIsNotNone(x) x is not None
assertIn(a, b) a in b
assertNotIn(a, b) a not in b
assertIsInstance(a, b) isinstance(a, b)
assertNotIsInstance(a, b) not isinstance(a, b)
assertAlmostEqual(a, b) round(a-b, 7) == 0
assertNotAlmostEqual(a, b) round(a-b, 7) != 0
assertGreater(a, b) a > b
assertGreaterEqual(a, b) a >= b
assertLess(a, b) a < b
assertLessEqual(a, b) a <= b
assertRegex(s, r) r.search(s)
assertNotRegex(s, r) not r.search(s)

とりあえず、この程度わかれば、テストは書けます。

Pythonのプロファイラ

Pythonには2つの組み込みプロファイラがあります。

  • Pythonで書かれた「profile」
  • C言語で書かれた「cProfile」

しかし、特別な理由がない限り「cProfile」を使います。

理由は単純でオーバーヘッドが少ないからです。

プロファイラは各メソッドの時間を計測するものなので、プロファイラ自身のオーバーヘッドの影響は少ないほうがいいに決まってます。

使い方は、ほぼ定型的です。(自分の場合は・・ですが)

必要なimportは以下です。

from cProfile import Profile
from pstats import Stats

計測は、Profileオブジェクトを生成して、計測対象のメソッドを「runcall」します。

profiler = Profile()
profiler.runcall(test_func)

計測結果の表示には、psstatsのStatsオブジェクトに計測したprofileを渡して生成して、以下のような手順で表示します。

sortは、cumulative(累積)時間順になります。

stats = Stats(profiler)
stats.strip_dirs()
stats.sort_stats('cumulative')
stats.print_stats()

自分は、これをunittestのテストケースの一つに組み込んで、時間がかかりすぎている部分がないかをおおまかにチェックする感じで使ってます。

効率の悪い処理をわざと実行して計測テスト

試してみます。

遅いソートアルゴリズムバブルソートの何の工夫もしてないもの)を実装して、それで10000件のランダムなリストをソートして、結果が正しいかをチェックするという処理を「z_tmp.py」という名前で作成しました。

def normal_bubble_sort(data):
    for i in range(len(data)):
        for j in range(len(data) - i - 1):    # ソート済みの部分以外でループ
            if data[j] > data[j + 1]:    # 前のほうが大きいとき
                data[j], data[j + 1] = data[j + 1], data[j]
    return data


def check_sorted_data(data):
    sorted_result = normal_bubble_sort(data)
    for i in range(len(sorted_result)):
        if i > 0:
            if sorted_result[i - 1] > sorted_result[i]:
                return False
    return True

これを単体テストとプロファイルするコードはこんな感じになります。

import unittest
from cProfile import Profile
from pstats import Stats
from random import randint
import z_tmp


class TestWithProfile(unittest.TestCase):

    def setUp(self):
        # データ設定
        self.max_size = 10000
        self.data = [randint(0, self.max_size) for _ in range(self.max_size)]

    def test_func(self):
        self.assertTrue(z_tmp.check_sorted_data(self.data))

    def test_profile(self):
        profiler = Profile()
        profiler.runcall(self.test_func)
        stats = Stats(profiler)
        stats.strip_dirs()
        stats.sort_stats('cumulative')
        print("++++++++++++++++++++++++++++++++++++++++++++++")
        stats.print_stats()
        print("++++++++++++++++++++++++++++++++++++++++++++++")
        self.assertTrue(1)


if __name__ == '__main__':
    unittest.main()

実行してみます。

しばらく、だんまりした後で、OKと一緒に以下のようなログが出力されました。

 Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   12.960   12.960 test_with_profile.py:15(test_func)
        1    0.002    0.002   12.960   12.960 z_tmp.py:9(check_sorted_data)
        1   12.956   12.956   12.959   12.959 z_tmp.py:1(normal_bubble_sort)
    10002    0.002    0.000    0.002    0.000 {built-in method builtins.len}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        1    0.000    0.000    0.000    0.000 case.py:761(assertTrue)

見事に実行時間のほとんどを、 z_tmp.py:の「normal_bubble_sort)」メソッドが消費していることが示されています。

なので、これを高速化するには、その部分をもっと速いソートアルゴリズムにおきかえてやればいい・・とわかります。

よしよし・・です。