博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
day10决策树和随机森林实践
阅读量:6791 次
发布时间:2019-06-26

本文共 3890 字,大约阅读时间需要 12 分钟。

 

 

 10.1

 

注:左半边为欠拟合,右半边过拟合,,此图的主要目的是告诉我们如何调参。 

#!/usr/bin/python# -*- coding:utf-8 -*-import numpy as npimport matplotlib.pyplot as pltimport matplotlib as mplfrom sklearn import treefrom sklearn.tree import DecisionTreeClassifierfrom sklearn.model_selection import train_test_splitfrom sklearn.preprocessing import StandardScalerfrom sklearn.pipeline import Pipelinedef iris_type(s):    it = {b'Iris-setosa': 0, b'Iris-versicolor': 1, b'Iris-virginica': 2}    return it[s]# 花萼长度、花萼宽度,花瓣长度,花瓣宽度# iris_feature = 'sepal length', 'sepal width', 'petal length', 'petal width'iris_feature = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度'if __name__ == "__main__":    mpl.rcParams['font.sans-serif'] = [u'SimHei']    mpl.rcParams['axes.unicode_minus'] = False    # path = '..\\8.Regression\\8.iris.data'  # 数据文件路径    path = '8.iris.data'  # 数据文件路径    data = np.loadtxt(path, dtype=float, delimiter=',', converters={4: iris_type})    x, y = np.split(data, (4,), axis=1)    # 为了可视化,仅使用前两列特征    x = x[:, :2]    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1)    #ss = StandardScaler()    #ss = ss.fit(x_train)    # 决策树参数估计    # min_samples_split = 10:如果该结点包含的样本数目大于10,则(有可能)对其分支    # min_samples_leaf = 10:若将某结点分支后,得到的每个子结点样本数目都大于10,则完成分支;否则,不进行分支    model = Pipeline([        ('ss', StandardScaler()),        ('DTC', DecisionTreeClassifier(criterion='entropy', max_depth=3))])    # clf = DecisionTreeClassifier(criterion='entropy', max_depth=3)    model = model.fit(x_train, y_train)    y_test_hat = model.predict(x_test)      # 测试数据    # 保存    # dot -Tpng -o 1.png 1.dot    f = open('.\\iris_tree.dot', 'w')    tree.export_graphviz(model.get_params('DTC')['DTC'], out_file=f)    # 画图    N, M = 100, 100  # 横纵各采样多少个值    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()  # 第0列的范围    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()  # 第1列的范围    t1 = np.linspace(x1_min, x1_max, N)    t2 = np.linspace(x2_min, x2_max, M)    x1, x2 = np.meshgrid(t1, t2)  # 生成网格采样点    x_show = np.stack((x1.flat, x2.flat), axis=1)  # 测试点    # # 无意义,只是为了凑另外两个维度    # # 打开该注释前,确保注释掉x = x[:, :2]    # x3 = np.ones(x1.size) * np.average(x[:, 2])    # x4 = np.ones(x1.size) * np.average(x[:, 3])    # x_test = np.stack((x1.flat, x2.flat, x3, x4), axis=1)  # 测试点    cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])    cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])    y_show_hat = model.predict(x_show)  # 预测值    y_show_hat = y_show_hat.reshape(x1.shape)  # 使之与输入的形状相同    plt.figure(facecolor='w')    plt.pcolormesh(x1, x2, y_show_hat, cmap=cm_light)  # 预测值的显示    plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test.ravel(), edgecolors='k', s=100, cmap=cm_dark, marker='o')  # 测试数据    plt.scatter(x[:, 0], x[:, 1], c=y.ravel(), edgecolors='k', s=40, cmap=cm_dark)  # 全部数据    plt.xlabel(iris_feature[0], fontsize=15)    plt.ylabel(iris_feature[1], fontsize=15)    plt.xlim(x1_min, x1_max)    plt.ylim(x2_min, x2_max)    plt.grid(True)    plt.title(u'鸢尾花数据的决策树分类', fontsize=17)    plt.show()    # 训练集上的预测结果    y_test = y_test.reshape(-1)    print(y_test_hat)    print(y_test)    result = (y_test_hat == y_test)   # True则预测正确,False则预测错误    acc = np.mean(result)    print('准确度: %.2f%%' % (100 * acc))    # 过拟合:错误率    depth = np.arange(1, 15)    err_list = []    for d in depth:        clf = DecisionTreeClassifier(criterion='entropy', max_depth=d)        clf = clf.fit(x_train, y_train)        y_test_hat = clf.predict(x_test)  # 测试数据        result = (y_test_hat == y_test)  # True则预测正确,False则预测错误        err = 1 - np.mean(result)        err_list.append(err)        print(d, ' 准确度: %.2f%%' % (100 * err))    plt.figure(facecolor='w')    plt.plot(depth, err_list, 'ro-', lw=2)    plt.xlabel(u'决策树深度', fontsize=15)    plt.ylabel(u'错误率', fontsize=15)    plt.title(u'决策树深度与过拟合', fontsize=17)    plt.grid(True)    plt.show()
代码

10.2

 

转载于:https://www.cnblogs.com/chengxiaofeng/p/10811281.html

你可能感兴趣的文章
GPU-BASED PROCEDURAL PLACEMENT IN HORIZON ZERO DAWN
查看>>
mysql中[Err] 1366 - Incorrect string value: '\xE5\x8D\x问题
查看>>
Hadoop生态上几个技术的关系与区别:hive、pig、hbase 关系与区别
查看>>
Mysql用户管理(远程连接、授权)
查看>>
Coursera机器学习编程作业Python实现(Andrew Ng)—— 2.1 Logistic Regression
查看>>
前台动态增加行,并将结果打印到XML文件
查看>>
简单回溯,最少步数
查看>>
LeetCode – Refresh – Palindrome Partitioning II
查看>>
mysql线上数据库单表超过200G的处理
查看>>
生成静态页相关
查看>>
OC中ARC forbids explicit message send of release错误
查看>>
J2SE 学习记录
查看>>
VS静态编译
查看>>
个人作业——Alpha项目测试
查看>>
laravel之laravel-admin安装
查看>>
浅谈C#中的接口和抽象类
查看>>
Jmeter实现webservice的接口测试
查看>>
jmeter用BeanShell调用jar包对HTTP请求中的参数进行MD5加密
查看>>
判断页数及切换
查看>>
GraphQL ---02 GraphQL和C#结合的实战项目
查看>>