[Python][Numpy] 組み込みのboolとNumpy.bool_は別物
Python組み込みのbool型とnumpy.bool_型は別物なので、is
演算子で比較する場合は注意が必要、という話。
環境:Windows10 + Python 3.7 + Numpy 1.18.1
numpy.bool_型とPython組み込みのbool型は別物
Numpyを使って真偽値を判定すると、numpy.bool_
という型になって返ってくるが、実はnumpy.bool_
はPython組み込みのboolとは別の型である。
次のコードを見てもらいたい。
コード1:
import numpy as np x = np.log(10) > 1 # -> True print(x, type(x)) print(x is True) print(x is False)
出力1:
True <class 'numpy.bool_'> False False
3行目でx
に、np.log(10) > 1
の計算結果であるTrueが代入されるのだが、その型は組み込みのbool型ではなくnumpy.bool_
というNumpyで実装されているbool型である。
通常、組み込みbool型に対しては、is
演算子を使って、TrueかFalseかを判定する。もちろん==
演算子を使っても判定はできるのだが、Pythonのコードスタイルではis
演算子を使う方法が推奨されている。
www.flake8rules.com
しかし、numpy.bool_
のTrueと組み込みbool型のTrueは別物なので、is
演算子で比較すると、Falseが返ってきてしまう。つまり、「TrueでもないしFalseでもない」ように見えてしまうのである。困るのは、次のような条件分岐を書いた場合である。
コード2
import numpy as np x = np.log(10) > 1 # -> True if x is True: print('True!') else: print('False!')
出力2
False!
xにはTrueが代入されているので、'True'がprintされてほしいところ、'False'がprintされてしまう。
解決策1:bool関数で型変換する
Python組み込みのbool関数を使って、numpy.bool_型を組み込みのbool型に変換してやればよい。
コード3
import numpy as np x = bool(np.log(10) > 1) # -> True print(x, type(x)) print(x is True) print(x is False)
出力3
True <class 'bool'> True False
解決策2:==
演算子で比較する
is
演算子ではなく、==
演算子を使って比較すれば、期待通りの結果が得られる。
コード4
import numpy as np x = np.log(10) > 1 # -> True print(x, type(x)) print(x == True) print(x == False)
出力4
True <class 'numpy.bool_'> True False
ただし、Pythonの推奨コードスタイルから外れるので、エディタによっては注意が表示されるかもしれない*1。その場合、NOQA
コメントを追加すれば、表示を一時的に消すことができる。
コード5
import numpy as np x = np.log(10) > 1 # -> True print(x, type(x)) print(x == True) # NOQA print(x == False) # NOQA
出力5
(出力4と同じ。省略)
解決策3:if True:
の形で使う
そもそもis
演算子を使わずに、条件式をそのまま使えば問題は起こらない。ただし、and
やor
を多用した複雑な条件式の場合は、is
演算子を使って分かりやすく書きたくなるので悩ましいところではある。
コード6
import numpy as np x = np.log(10) > 1 # -> True if x: print('True!') else: print('False!')
出力6
True!
補足:numpy.bool_()関数の挙動
最後に、numpy.bool_()関数の挙動を調べた結果をメモとして残しておく。なぜかnumpy公式ドキュメントには、numpy.bool_関数の説明ページが見当たらなかった。
import numpy as np print(np.bool_(0)) # -> False print(np.bool_(1)) # -> True print(np.bool_(2)) # -> True print(np.bool_(3)) # -> True print(np.bool_(True)) # -> True print(np.bool_(False)) # -> False
*1:私の環境では "comparison to False should be 'if cond is False:' or 'if not cond:'pycodestyle(E712)" という注意が表示された。