2017年7月21日金曜日

Skleranを使ってみる3

前回の『Sklearnを使ってみる2』を引き続き読み進めていく。

本の方では、『Setosaとそれ以外という分類は簡単にできたが、ではSetosa以外をVirginicaとVirsicolorに分類するのはどうすればいいか、簡単にできそうにはない』と話が進んでいく。

ではどうするかというと、とりあえずの回答として『力技で総当たりで閾値を探す』という方法をとる。

で、まず手始めの準備として、Setosa以外だけから成る、特徴量とラベルを改めて抽出しなおしておく。
(不要なところは#でコメントアウトしてある。)
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
import numpy as np

data = load_iris()

features=data["data"]
feature_names=data["feature_names"]
target=data["target"]
target_names=data["target_names"]
labels=target_names[target]


#for t,marker,c in zip (range(3),'>ox','rgb'):
#    
#    plt.scatter(features[target == t,0],features[target == t,1],marker=marker,c=c)
#
#plt.show()


plength=features[:,2]
is_setosa=(labels=='setosa')

max_setosa=plength[is_setosa].max()
min_non_setosa=plength[~is_setosa].min()

#print('Maximum of setosa:{0}'.format(max_setosa))
#print('Minimum of others:{0}'.format(min_non_setosa))

features=features[~is_setosa]
labels=labels[~is_setosa]
virginica=(labels=="virginica")

最後の三行
features=features[~is_setosa]
labels=labels[~is_setosa]
virginica=(labels=="virginica")
が新たに加わった分。 最初の二行も前回やったブールインデックス参照。

is_setosaがsetosaのところがTrueでそれ以外がFalseの配列なので、それをTrueとFalseを反転させている。

なので~is_setosaは、setosa以外がTrueになっていて、setosaがFalseになっています。
で、ブールインデックス参照を使って、配列からsetosa以外に該当するものを抽出している。

三行目は二行目で抽出したsetosa以外のラベルから、特にvirginicaに一致するものだけをさらに抽出し、ブーリアン配列を生成する。

で、本題の閾値を探す部分のコードが以下。
#初期値の値は処理に影響がないように適当に決める。
とりあえず絶対に影響がないような値として負の値で設定しています。
best_acc=-1.0       #正解率の値の初期値
best_fi=-1.0          #閾値となる特徴量のラベルの値の初期値
best_t=-1.0           #具体的な閾値の数値の初期値


for fi in range(features.shape[1]):    #features.shapeの値はこのとき(100,4)となっている。
なのでfeatures.shape[1]は4。つまり特徴量の個数です。全ての特徴量を総当たりで試すという方針。

    thresh=features[:,fi].copy()       #特徴量を一つ固定してそれに対する数値を全部抽出
    thresh.sort()         #上の行で抽出した数値をソートする。

    for t in thresh:       #測定値一つ一つに対して、「それが閾値になりうるか」を総当たりで試す。

        pred=(features[:,fi]>t)      #測定値一つをとりあえず閾値として固定して、
その閾値以上となっている測定値に対して新しいブーリアン配列としている。

        if (labels[pred]=="virginica")!=():     #accがnanになることを排除。詳細は下記で。
        acc=(labels[pred]=="virginica".mean())      

#上の行でつくったブーリアン配列に対して、さらにラベルが"virginica"に一致しているものを抽出。
#つまり、結果としては、測定値が閾値以上でかつvirginicaであるものを抽出してその平均をとっている。
        if acc > best_acc:    #よりよい結果を出した「特徴量と閾値の組み合わせ」に更新する処理。
          best_acc=acc      #もし、よりいい正解率が出たのなら、よりその時点での最良正解率としてそれに更新
          best_fi=fi             #もし、よりいい正解率がでたのなら、そのときの特徴量を採用。
          best_t=t              #もし、よりいい正解率がでたのなら、そのときの閾値を採用。


本文とは違い、if(labels[pred]=="virginica")!=():  という行を追加していますが、これなしだと、accがnanになってしまうことがあるので追加している。

最後に以下を追加して一段落。
def apply_model(example):
   if example[best_fi] > best_t : print("virginica")
   else:print("versicolor")

0 件のコメント:

コメントを投稿