[Python][Numpy] 組み込みのboolとNumpy.bool_は別物

Python組み込みのbool型とnumpy.bool_型は別物なので、is演算子で比較する場合は注意が必要、という話。

環境:Windows10 + Python 3.7 + Numpy 1.18.1

numpy.bool_型とPython組み込みのbool型は別物

Numpyを使って真偽値を判定すると、numpy.bool_という型になって返ってくるが、実はnumpy.bool_Python組み込みのboolとは別の型である。

次のコードを見てもらいたい。

コード1:

import numpy as np

x = np.log(10) > 1  # -> True
print(x, type(x))
print(x is True)
print(x is False)

出力1:

True <class 'numpy.bool_'>
False
False

3行目でxに、np.log(10) > 1の計算結果であるTrueが代入されるのだが、その型は組み込みのbool型ではなくnumpy.bool_というNumpyで実装されているbool型である。

通常、組み込みbool型に対しては、is演算子を使って、TrueかFalseかを判定する。もちろん==演算子を使っても判定はできるのだが、Pythonのコードスタイルではis演算子を使う方法が推奨されている。
www.flake8rules.com

しかし、numpy.bool_のTrueと組み込みbool型のTrueは別物なので、is演算子で比較すると、Falseが返ってきてしまう。つまり、「TrueでもないしFalseでもない」ように見えてしまうのである。困るのは、次のような条件分岐を書いた場合である。

コード2

import numpy as np

x = np.log(10) > 1  # -> True
if x is True:
    print('True!')
else:
    print('False!')

出力2

False!

xにはTrueが代入されているので、'True'がprintされてほしいところ、'False'がprintされてしまう。

バグではない

念のため書いておくと、これはバグではない。そういう仕様である。

github.com

バグではないとは言え、直感に反する挙動をされると困る。そこで解決策をいくつか考えてみる。

解決策1:bool関数で型変換する

Python組み込みのbool関数を使って、numpy.bool_型を組み込みのbool型に変換してやればよい。

コード3

import numpy as np

x = bool(np.log(10) > 1)  # -> True
print(x, type(x))
print(x is True)
print(x is False)

出力3

True <class 'bool'>
True
False

解決策2:==演算子で比較する

is演算子ではなく、==演算子を使って比較すれば、期待通りの結果が得られる。

コード4

import numpy as np

x = np.log(10) > 1  # -> True
print(x, type(x))
print(x == True)
print(x == False)

出力4

True <class 'numpy.bool_'>
True
False

ただし、Pythonの推奨コードスタイルから外れるので、エディタによっては注意が表示されるかもしれない*1。その場合、NOQAコメントを追加すれば、表示を一時的に消すことができる。

コード5

import numpy as np

x = np.log(10) > 1  # -> True
print(x, type(x))
print(x == True)  # NOQA
print(x == False)  # NOQA

出力5
(出力4と同じ。省略)

解決策3:if True:の形で使う

そもそもis演算子を使わずに、条件式をそのまま使えば問題は起こらない。ただし、andorを多用した複雑な条件式の場合は、is演算子を使って分かりやすく書きたくなるので悩ましいところではある。

コード6

import numpy as np

x = np.log(10) > 1  # -> True
if x:
    print('True!')
else:
    print('False!')

出力6

True!

補足:numpy.bool_()関数の挙動

最後に、numpy.bool_()関数の挙動を調べた結果をメモとして残しておく。なぜかnumpy公式ドキュメントには、numpy.bool_関数の説明ページが見当たらなかった。

import numpy as np

print(np.bool_(0))  # -> False
print(np.bool_(1))  # -> True
print(np.bool_(2))  # -> True
print(np.bool_(3))  # -> True

print(np.bool_(True))  # -> True
print(np.bool_(False))  # -> False

*1:私の環境では "comparison to False should be 'if cond is False:' or 'if not cond:'pycodestyle(E712)" という注意が表示された。