【Python】sklearn.datasets使用(数据集、常用函数、示例代码)
顽石九变 2024-10-26 08:35:02 阅读 96
文章目录
一、sklearn.datasets数据集二、sklearn.datasets常用函数1. 加载数据集2. 生成数据集3. 图像数据集
三、各数据集的作用1、合成数据集2、真实数据集
四、数据集使用代码示例五、详细使用示例load_digits 完整使用fetch_20newsgroups 完整使用
六、常用函数详细介绍make_regression 生成一个随机的回归问题train_test_split 划分训练集和测试集classification_report 分类报告
参考
一、sklearn.datasets数据集
<code>sklearn.datasets 中包含了多种多样的数据集,这些数据集主要可以分为以下几大类:
玩具数据集(Toy datasets):
波士顿房价(Boston Housing)数据集:包含了506个波士顿地区的房屋数据,每个数据点有13个特征变量和一个目标变量(房屋价格的中位数)。鸢尾花(Iris)数据集:这是一个常用的分类数据集,包含了3种不同的鸢尾花及其4个特征。糖尿病(Diabetes)数据集:关于糖尿病患者的一些生理指标与一年后的疾病进展指数之间的关系。手写数字(Digits)数据集:包含了1797个手写数字的图像数据。还有其他一些小型标准数据集,如乳腺癌(Breast Cancer)数据集等。
真实世界中的数据集(Real-world datasets):
这些数据集通常需要通过fetch_
函数从网络上下载,它们是近年来真实收集的数据,适用于更复杂的机器学习任务。例如,新闻组(20 Newsgroups)数据集,这是一个用于文本分类的大型数据集。
样本生成器(Sample generators):
sklearn.datasets
还提供了一系列函数来生成人工数据集,如make_classification
、make_regression
等。这些函数可以根据用户指定的参数生成用于分类、回归等任务的数据集。
样本图片(Sample images):
sklearn.datasets
也包含了一些用于图像处理和计算机视觉任务的数据集,如Olivetti人脸识别数据集等。
SVMLight或LibSVM格式的数据:
可以加载SVMLight或LibSVM格式的数据集,这些格式常用于机器学习竞赛和研究中。
从OpenML下载的数据:
OpenML是一个用于机器学习数据和实验的公共存储库。通过sklearn.datasets.fetch_openml()
函数,可以从OpenML下载各种数据集。
需要注意的是,sklearn.datasets
中的数据集主要是为了方便教学和入门学习而提供的。在实际应用中,可能需要使用更大规模、更复杂的数据集来训练模型。此外,随着时间的推移,sklearn
库可能会更新和添加新的数据集,因此建议查阅最新的官方文档以获取最准确的信息。
二、sklearn.datasets常用函数
sklearn.datasets
模块提供了多个函数来加载和生成数据集。以下是一些sklearn.datasets
中常用的函数:
1. 加载数据集
load_iris()
: 加载鸢尾花数据集,这是一个常用的多类分类数据集。load_digits()
: 加载手写数字数据集,每个实例都是一张8x8的数字图像及其对应的数字类别。load_boston()
: 加载波士顿房价数据集,这是一个回归问题的数据集。load_breast_cancer()
: 加载乳腺癌数据集,这是一个二分类问题的数据集。load_diabetes()
: 加载糖尿病数据集,这个数据集可以用于回归分析。fetch_20newsgroups(subset='train')code>: 下载和加载20个新闻组文本数据集的一个子集,用于文本分类或聚类。
fetch_openml(name='mnist_784', ...)code>: 从OpenML获取数据集,例如著名的MNIST手写数字数据集。
2. 生成数据集
make_blobs(n_samples=100, ...)
: 生成用于聚类的随机数据。make_classification(n_samples=100, ...)
: 生成一个随机的二分类问题。make_regression(n_samples=100, ...)
: 生成一个随机的回归问题。make_moons(n_samples=100, ...)
: 生成两个交错的月牙形状的数据集,用于分类。make_circles(n_samples=100, ...)
: 生成两个环形的数据集,一个内环,一个外环,也用于分类。make_s_curve(n_samples=100, ...)
: 生成S形曲线数据集,这是一个非线性可分的数据集。
3. 图像数据集
fetch_olivetti_faces()
: 下载Olivetti人脸数据集,通常用于人脸识别或图像处理。fetch_lfw_people(min_faces_per_person=70, ...)
: 下载Labeled Faces in the Wild (LFW) 人脸数据集的一个子集。fetch_lfw_pairs(subset='train', ...)code>: 下载LFW人脸数据集中的成对面部图像,用于验证面部识别算法。
这些函数提供了方便的方式来获取或生成数据,以便在机器学习项目中进行测试和验证。每个函数都有许多参数可以调整,以便根据需要生成或加载特定类型的数据集。在使用这些函数时,请务必查阅官方文档以了解所有可用的参数和选项。
三、各数据集的作用
1、合成数据集
make_blobs:
作用:生成随机的数据集聚类,通常用于测试聚类算法。特点:可以通过调整参数来控制数据点的数量、中心点的数量、标准差等,从而生成具有不同特性的聚类数据。
make_classification:
作用:生成一个随机的二分类问题,用于测试分类算法。特点:可以指定样本数量、特征数量、类别分隔的难易程度等参数,以模拟不同复杂度的分类问题。
make_regression:
作用:生成一个随机的回归问题,用于测试回归算法。特点:可以设定输入特征的数量、噪声水平等,以生成具有不同特性的回归数据。
make_moons:
作用:生成两个交错的半圆形数据,通常用于测试分类算法的性能,特别是在处理非线性可分数据时的表现。特点:可以通过调整噪声水平、两个半圆的间距等参数来增加分类问题的难度。
make_circles:
作用:生成两个环形的数据集,一个在内环,一个在外环,也常用于测试分类算法处理非线性数据的能力。特点:与make_moons
类似,这也是一个用于测试非线性分类问题的数据集生成器。
2、真实数据集
load_iris:
作用:提供鸢尾花数据集,这是一个经典的机器学习数据集,包含三种不同类型的鸢尾花(山鸢尾、变色鸢尾、维吉尼亚鸢尾)及其四个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度)。应用场景:常用于多类分类问题的实验和演示。
load_digits:
作用:提供手写数字(0-9)的数据集,每个样本都是8x8的图像。这是一个用于图像识别和机器学习入门的数据集。应用场景:常用于图像分类、特征提取等任务的实验和演示。
load_boston:
作用:提供波士顿房价数据集,包含波士顿郊区房屋的中位数价格以及各项与房价相关的特征(如犯罪率、住宅平均房间数、是否靠近查尔斯河等)。应用场景:主要用于回归问题的实验和演示,特别是房价预测等经济分析领域。
load_breast_cancer:
作用:提供乳腺癌数据集,包含良性和恶性两种类别的乳腺肿块样本及其相关特征(如肿块半径、纹理、周长等)。应用场景:常用于医学领域的分类问题实验和演示,特别是癌症检测和诊断。
fetch_20newsgroups:
作用:提供一个文本分类的数据集,包含来自20个不同新闻组的文档(如计算机、科学、政治、体育等主题)。这是一个用于自然语言处理和文本分类任务的数据集。应用场景:常用于文本挖掘、主题建模、情感分析等领域的实验和演示。
四、数据集使用代码示例
以下是sklearn.datasets
中几个数据集的示例使用代码:
加载鸢尾花(Iris)数据集
from sklearn.datasets import load_iris
iris = load_iris()
X, y = iris.data, iris.target
print(f"特征数量: { X.shape[1]}")
print(f"类别数量: { len(set(y))}")
加载手写数字(Digits)数据集
from sklearn.datasets import load_digits
digits = load_digits()
X, y = digits.data, digits.target
print(f"图像数量: { len(X)}")
print(f"每张图像的特征数量: { X[0].shape[0]}")
加载波士顿房价(Boston)数据集
from sklearn.datasets import load_boston
boston = load_boston()
X, y = boston.data, boston.target
print(f"特征数量: { X.shape[1]}")
print(f"房价样本数量: { len(y)}")
加载乳腺癌(Breast Cancer)数据集
from sklearn.datasets import load_breast_cancer
cancer = load_breast_cancer()
X, y = cancer.data, cancer.target
print(f"特征数量: { X.shape[1]}")
print(f"样本数量: { len(y)},其中0代表良性,1代表恶性")
加载糖尿病(Diabetes)数据集
from sklearn.datasets import load_diabetes
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target
print(f"特征数量: { X.shape[1]}")
print(f"患者数量: { len(y)}")
下载和加载20个新闻组(20newsgroups)文本数据集的一个子集
from sklearn.datasets import fetch_20newsgroups
# 仅下载训练子集并取部分类别
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
newsgroups = fetch_20newsgroups(subset='train', categories=categories)code>
X, y = newsgroups.data, newsgroups.target
print(f"文档数量: { len(X)}")
print(f"类别数量: { len(set(y))}")
从OpenML获取MNIST手写数字数据集
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist.data, mnist.target
print(f"图像数量: { len(X)}")
print(f"每张图像的特征数量: { X[0].shape[0]}")
请注意,在实际使用中,您可能需要根据自己的需求对这些数据集进行预处理,例如特征缩放、编码分类变量等。此外,fetch_openml
可能需要网络连接以下载数据集。如果您使用的是Jupyter Notebook或类似的环境,并且数据集很大,下载可能需要一些时间。
五、详细使用示例
load_digits 完整使用
该函数加载了一个手写数字数据集,其中包含了1797个8x8的数字图像以及对应的数字类别。
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
# 加载手写数字数据集
digits = load_digits()
# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)
# 创建一个支持向量机分类器
clf = SVC(gamma=0.001)
# 在训练集上训练分类器
clf.fit(X_train, y_train)
# 使用训练好的分类器对测试集进行预测
y_pred = clf.predict(X_test)
# 计算预测的准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: { accuracy:.2f}")
# 可视化测试集中的一些图像及其预测结果
test_images = X_test[:16] # 取测试集的前16张图像
test_labels = y_test[:16] # 和对应的标签
predictions = y_pred[:16] # 和对应的预测结果
images_to_show = test_images.reshape(-1, 8, 8) # 将图像数据重塑为8x8的形状
fig, axes = plt.subplots(4, 4)
fig.subplots_adjust(hspace=1, wspace=0.5)
for i, ax in enumerate(axes.flat):
ax.imshow(images_to_show[i], cmap='gray_r')code>
ax.set_title(f"True: { test_labels[i]}")
ax.set_xlabel(f"Prediction: { predictions[i]}")
ax.set_xticks(())
ax.set_yticks(())
plt.show()
fetch_20newsgroups 完整使用
fetch_20newsgroups 是 scikit-learn 库中的一个函数,用于加载著名的 20 Newsgroups 文本数据集。这个数据集包含了大约 20,000 个新闻组文档,均匀分为 20 个不同的类别。
常用于文本挖掘、主题建模、情感分析等领域的实验和演示。
<code>from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
# 加载 20 Newsgroups 数据集
newsgroups = fetch_20newsgroups(subset='all') code>
# 显示数据集的一些基本信息
print("数据集大小:", len(newsgroups.data))
print("目标类别:", newsgroups.target_names)
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(newsgroups.data, newsgroups.target, test_size=0.2, random_state=42)
# 使用 CountVectorizer 进行文本向量化
vectorizer = CountVectorizer(stop_words='english') code>
X_train_vectors = vectorizer.fit_transform(X_train)
X_test_vectors = vectorizer.transform(X_test)
# 使用多项式朴素贝叶斯分类器进行训练
clf = MultinomialNB()
clf.fit(X_train_vectors, y_train)
# 在测试集上进行预测
y_pred = clf.predict(X_test_vectors)
# 打印分类报告
print(classification_report(y_test, y_pred, target_names=newsgroups.target_names))
打印输出
数据集大小: 18846
目标类别: ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']
precision recall f1-score support
alt.atheism 0.86 0.91 0.88 151
comp.graphics 0.67 0.92 0.78 202
comp.os.ms-windows.misc 0.96 0.35 0.52 195
comp.sys.ibm.pc.hardware 0.61 0.86 0.71 183
comp.sys.mac.hardware 0.89 0.90 0.90 205
comp.windows.x 0.85 0.85 0.85 215
misc.forsale 0.93 0.70 0.80 193
rec.autos 0.90 0.95 0.93 196
rec.motorcycles 0.95 0.94 0.95 168
rec.sport.baseball 0.98 0.96 0.97 211
rec.sport.hockey 0.95 0.97 0.96 198
sci.crypt 0.92 0.96 0.94 201
sci.electronics 0.92 0.83 0.87 202
sci.med 0.95 0.95 0.95 194
sci.space 0.92 0.97 0.95 189
soc.religion.christian 0.88 0.99 0.93 202
talk.politics.guns 0.89 0.93 0.91 188
talk.politics.mideast 0.95 1.00 0.97 182
talk.politics.misc 0.84 0.89 0.87 159
talk.religion.misc 0.95 0.58 0.72 136
accuracy 0.87 3770
macro avg 0.89 0.87 0.87 3770
weighted avg 0.89 0.87 0.87 3770
六、常用函数详细介绍
make_regression 生成一个随机的回归问题
make_regression
是sklearn.datasets
模块中的一个函数,用于生成回归问题的样本数据。以下是关于make_regression
函数的详细介绍:
函数语法
sklearn.datasets.make_regression(n_samples=100, n_features=100, n_informative=10, n_targets=1, bias=0.0, effective_rank=None, tail_strength=0.5, noise=0.0, shuffle=True, coef=False, random_state=None)
参数说明
n_samples:整数,可选(默认为100)。表示生成的样本数量。n_features:整数,可选(默认为100)。表示每个样本的特征数量,即自变量的个数。n_informative:整数,可选(默认为10)。表示有信息的特征数量,这些特征将被用来构造线性模型以生成输出。换句话说,这些特征是对预测目标变量有用的。n_targets:整数,可选(默认为1)。表示回归目标的数量,即对应于一个样本输出向量y的维度。在大多数情况下,这是一个标量值,但也可以设置为多个目标。bias:浮点数,可选(默认为0.0)。表示偏置项或截距。effective_rank:整数或None,可选(默认为None)。如果非None,则数据是通过这种方式生成的:先生成一个正态分布的随机矩阵,然后用其奇异值分解的前effective_rank
个奇异向量和奇异值来构造数据矩阵。如果为None,则使用完整的奇异值分解来生成数据。tail_strength:浮点数,可选(默认为0.5)。该参数影响了特征之间的相关性衰减速度。较大的tail_strength
值意味着特征之间将具有更强的相关性。noise:浮点数,可选(默认为0.0)。表示添加到输出中的高斯噪声的标准差。噪声是添加到由线性模型生成的输出上的。shuffle:布尔值,可选(默认为True)。如果为True,则打乱样本和特征的顺序。如果为False,则按原始顺序返回样本和特征。coef:布尔值,可选(默认为False)。如果为True,则返回数据的系数(即线性模型的权重)。这些系数可用于了解哪些特征对目标变量的影响更大。random_state:整数、RandomState实例或None(默认为None)。控制随机数生成器的种子或RandomState实例。如果为整数,则它指定了随机数生成器的种子;如果为RandomState实例,则它本身就是随机数生成器;如果为None,则随机数生成器是np.random
。设置随机状态可以确保每次运行代码时生成相同的数据集,便于结果复现和比较。
返回值
函数返回以下值:
X:形状为(n_samples, n_features)
的数组,表示生成的特征数据。y:形状为(n_samples,)
或(n_samples, n_targets)
的数组,表示生成的目标变量(回归值)。如果设置了多个目标(n_targets > 1
),则y的形状将为(n_samples, n_targets)
。coef(仅在coef=True
时返回):形状为(n_informative,)
或(n_targets, n_informative)
的数组,表示用于生成数据的线性模型的系数。这些系数可以帮助理解哪些特征对目标变量的预测最重要。如果设置了多个目标(n_targets > 1
),则coef的形状将为(n_targets, n_informative)
,其中每一行对应于一个目标的系数向量。
train_test_split 划分训练集和测试集
train_test_split
是 scikit-learn
库中的一个函数,用于将数据集随机划分为训练集和测试集。这是机器学习中常用的一个步骤,目的是确保模型不仅在训练数据上表现良好,还能在未见过的数据(即测试数据)上有良好的表现。
函数的基本用法如下:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X
和 y
分别是特征和标签的数据。test_size
参数指定了测试集的大小,这里是总数据的20%。random_state
参数是一个随机种子,用于确保每次划分数据时都能得到相同的结果。这在需要重复实验或比较不同模型时非常有用。
函数返回四个数组:X_train
(训练集的特征)、X_test
(测试集的特征)、y_train
(训练集的标签)和y_test
(测试集的标签)。
classification_report 分类报告
classification_report
是 scikit-learn
库中的一个函数,它用于显示主要分类指标的文本报告,包括每个类的精确度(precision)、召回率(recall)、F1 分数(F1-score)以及支持数(即每个类别的样本数)。这个函数对于评估分类模型的性能非常有用,因为它提供了一个全面的、易于理解的性能指标概览。
下面是 classification_report
函数的基本用法:
from sklearn.metrics import classification_report, accuracy_score
y_true = [0, 0, 1, 2, 2, 2]
y_pred = [0, 0, 0, 2, 2, 1]
target_names = ['class 0', 'class 1', 'class 2']
print(classification_report(y_true, y_pred, target_names=target_names))
输出将类似于以下内容:
precision recall f1-score support
class 0 0.67 1.00 0.80 2
class 1 0.00 0.00 0.00 1
class 2 1.00 0.67 0.80 3
accuracy 0.67 6
macro avg 0.56 0.56 0.53 6
weighted avg 0.72 0.67 0.67 6
在这个报告中:
precision
(精确度)表示预测为正且实际为正的样本占所有预测为正的样本的比例。recall
(召回率)表示预测为正且实际为正的样本占所有实际为正的样本的比例。f1-score
是精确度和召回率的调和平均数,用于综合评价模型的性能。support
表示每个类别的样本数。
classification_report
还提供了 macro avg
和 weighted avg
,分别表示所有类别的指标的平均值(未加权的平均值)和根据每个类别的支持数加权的平均值。这些平均值有助于从整体上评估模型的性能。
计算公式解析
为了解释 precision、recall 和 F1-score 是如何计算的,我们先要理解几个基本概念:
真正例(True Positives, TP):实际为正例且被预测为正例的样本数。假正例(False Positives, FP):实际为负例但被预测为正例的样本数。真负例(True Negatives, TN):实际为负例且被预测为负例的样本数(在分类报告中通常不直接使用)。假负例(False Negatives, FN):实际为正例但被预测为负例的样本数。
现在,我们来看具体的计算方式:
1)精确度(Precision)
精确度是指预测为正例的样本中,真正为正例的比例。对于每个类别,其计算公式为:
[ \text{Precision} = \frac{TP}{TP + FP} ]
2)召回率(Recall)
召回率是指所有真正的正例中,被正确预测出来的比例。对于每个类别,其计算公式为:
[ \text{Recall} = \frac{TP}{TP + FN} ]
3)F1 分数(F1-score)
F1 分数是精确度和召回率的调和平均数,用于提供一个单一的指标来衡量分类器的性能。其计算公式为:
[ F1 = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}} ]
参考
https://www.cnblogs.com/Zshirly/p/15871498.htmlhttp://openml.org
声明
本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。