from sklearn.feature_extraction import DictVectorizer
import csv
from sklearn import tree
from sklearn import preprocessing
from sklearn.externals.six import StringIO

#Read in the csv file and put features into list of dict and list of class label
allElectronicsData = open(r'AllElectronics.csv', 'rt')
reader = csv.reader(allElectronicsData)
headers = next(reader)

print(headers)

featureList = []
labelList = []

for row in reader:
    # 把所有的结果放到这里,相当于
    labelList.append(row[len(row)-1])
    #存x值,以键值对的形式,键值从headers里面取,属性值从每行数据里面取
    rowDict = {}
    for i in range(1, len(row)-1):
        rowDict[headers[i]] = row[i]
    featureList.append(rowDict)

print(featureList)

#Vetorize features
#0-1化
#说明:DictVectorizer的处理对象是符号化(非数字化)的但是具有一定结构的特征数据,如字典等,将符号转成数字0/1表示。
#我们不难发现,DictVectorizer对非数字化的处理方式是,借助原特征的名称,组合成新的特征,并采用0/1的方式进行量化,
#而数值型的特征转化比较方便,一般情况维持原值即可。
vec = DictVectorizer()
#fit_transform():先拟合数据再标准化
#transform():标准化
dummyX = vec.fit_transform(featureList) .toarray()

print("dummyX: " + str(dummyX))
print(vec.get_feature_names())

print("labelList: " + str(labelList))

#vectorize class labels
#标签二值化:sklearn.preprocessing.LabelBinarizer(neg_label=0, pos_label=1,sparse_output=False)
#主要是将多类标签转化为二值标签,最终返回的是一个二值数组或稀疏矩阵
#参数说明:
#neg_label:输出消极标签值
#pos_label:输出积极标签值
#sparse_output:设置True时,以行压缩格式稀疏矩阵返回,否则返回数组
#classes_属性:类标签的取值组成数组
#①设置neg_label=2、pos_label=4,只能返回二值数组,理解neg_label、pos_label两标签值的含义

lb = preprocessing.LabelBinarizer()
dummyY = lb.fit_transform(labelList)
print("dummyY: " + str(dummyY))

#Using decision tree for classification
#clf = tree.DecisionTreeClassifier()
#决策树分类器
clf = tree.DecisionTreeClassifier(criterion='entropy')
clf = clf.fit(dummyX, dummyY)
print("clf: " + str(clf))

#Visualize model
with open("allElectronicInformationGainOri.dot", 'w') as f:
    f = tree.export_graphviz(clf, feature_names=vec.get_feature_names(), out_file=f)

oneRowX = dummyX[0, :]
print("oneRowX: " + str(oneRowX))

newRowX = oneRowX
newRowX[0] = 1
newRowX[2] = 0
print("newRowX: " + str(newRowX))

predictedY = clf.predict(newRowX.reshape(1, -1))
print("predictedY: " + str(predictedY))

数据样本: