【Python机器学习】KNN进行水果分类和分类器实战(附源码和数据集)

showswoller 2024-06-13 11:01:07 阅读 81

需要源码和数据集请点赞关注收藏后评论区留言私信~~~

KNN算法简介

KNN(K-Nearest Neighbor)算法是机器学习算法中最基础、最简单的算法之一。它既能用于分类,也能用于回归。KNN通过测量不同特征值之间的距离来进行分类。

KNN算法的思想非常简单:对于任意n维输入向量,分别对应于特征空间中的一个点,输出为该特征向量所对应的类别标签或预测值。

KNN算法是一种非常特别的机器学习算法,因为它没有一般意义上的学习过程。它的工作原理是利用训练数据对特征向量空间进行划分,并将划分结果作为最终算法模型。存在一个样本数据集合,也称作训练样本集,并且样本集中的每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。

输入没有标签的数据后,将这个没有标签的数据的每个特征与样本集中的数据对应的特征进行比较,然后提取样本中特征最相近的数据(最近邻)的分类标签。

一般而言,我们只选择样本数据集中前k个最相似的数据,这就是KNN算法中K的由来,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的类别,作为新数据的分类。

KNN分类算法的分类预测过程十分简单并容易理解:对于一个需要预测的输入向量x,我们只需要在训练数据集中寻找k个与向量x最近的向量的集合,然后把x的类别预测为这k个样本中类别数最多的那一类。

KNN算法中只有一个超参数k,k值的确定对KNN算法的预测结果有着至关重要的影响。接下来,我们讨论一下k值大小对算法结果的影响以及一般情况下如何选择k值。

如果k值比较小,相当于我们在较小的领域内训练样本对实例进行预测。这时,算法的近似误差(Approximate Error)会比较小,因为只有与输入实例相近的训练样本才会对预测结果起作用。

但是,它也有明显的缺点:算法的估计误差比较大,预测结果会对近邻点十分敏感,也就是说,如果近邻点是噪声点的话,预测就会出错。因此,k值过小容易导致KNN算法的过拟合。

同理,如果k值选择较大的话,距离较远的训练样本也能够对实例预测结果产生影响。这时候,模型相对比较鲁棒,不会因为个别噪声点对最终预测结果产生影响。但是缺点也十分明显:算法的近邻误差会偏大,距离较远的点(与预测实例不相似)也会同样对预测结果产生影响,使得预测结果产生较大偏差,此时模型容易发生欠拟合。

因此,在实际工程实践中,我们一般采用交叉验证的方式选取k值。通过以上分析可知,一般k值选得比较小,我们会在较小范围内选取k值,同时把测试集上准确率最高的那个确定为最终的算法超参数k。

使用KNN进行水果分类

部分数据如下

预测结果和精确度如下

 

部分代码如下

from sklearn import datasetsfrom sklearn.neighbors import KNeighborsClassifierfrom sklearn.model_selection import train_test_splitimport pandas as pd#导入水果数据并查看数据特征fruit = pd.read_csv('fruit_data.txt','\t')# 获取属性X = fruit.iloc[:,1:]# 获取类别Y = fruit.iloc[:,0].T# 划分成测试集和训练集fruit_train_X,fruit_test_X,fruit_train_y,fruit_test_y=train_test_split(X,Y,test_size=0.2, random_state=0)#分类eighborsClassifier()#对训练集进行训练knn.fit(fruit_train_X, fruit_train_y)#对测试集数据的水果类型进行预测predict_result = knn.predict(fruit_test_X)print('测试集大小:',fruit_test_X.shape)print('真实结果:',fruit_test_y)print('预it_test_y))

 绘制KNN分类器图

分类结果如下 可以看到鸢尾花数据集大致分为三类

 部分代码如下

import numpy as npfrom sklearn import neighbors, datasetsimport matplotlib.pyplot as pltfrom matplotlib.colors import ListedColormap# 建立KNN模型,使用前两个特征iris = datasets.load_iris()irisData = iris.data[:, :2] # Petal length、Petal width特征irisTarget = iris.targetclf = neors.KNeighborsClassifier(5) # K=5clf.fit(irisData, irisTarget)#绘制plot ColorMp = ListedColormap(['#005500', '#00AA00', '#00FF00'])X_min, X_max = irisData[:, 0].min(), irisData[:, 0].max()Y_minlabel = clf.predict(np.c_[X.ravel(), Y.ravel()])label = label.reshape(X.shape) #绘图并显示plt.figure()plt.pcolormesh(X,Y,label,cmap=ColorMp)plt.show()

创作不易 觉得有帮助请点赞关注收藏~~~



声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。