KNN 算法的工作原理非常直观:给定一个样本,KNN 会检查这个样本距离训练集中所有其他样本的距离,选择距离最近的 K 个样本,根据这 K 个样本的标签来预测目标样本的标签(分类问题)或数值(回归问题)。
👇 具体步骤:
- 选择一个适当的 K 值:即选择查看多少个邻居的标签。
- 计算距离:计算待分类样本与所有训练样本的距离(常用欧式距离)。
- 选择 K 个最近的邻居:找出距离最近的 K 个训练样本。
- 进行投票(分类)或平均(回归):
- 分类:选择 K 个邻居中出现最多的类别作为预测结果。
- 回归:计算 K 个邻居的均值作为预测结果。
📐 欧氏距离(Euclidean Distance)
KNN 算法通常使用欧氏距离来度量两个样本之间的相似度:
d(x, y) = sqrt( (x1 - y1)^2 + (x2 - y2)^2 + ... + (xn - yn)^2 )
x
和y
是两个样本,x1, x2,... xn
和y1, y2,... yn
是它们的特征。
🎯 KNN 适用场景
- 分类问题:例如判断图像中的物体类别、电子邮件是否为垃圾邮件等。
- 回归问题:例如预测房价、股票价格等。
✅ Python 实现 KNN(用 scikit-learn
)
🧪 示例:使用 KNN 分类鸢尾花数据集
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建 KNN 模型(选择 K=3)
knn = KNeighborsClassifier(n_neighbors=3)
# 训练模型
knn.fit(X_train, y_train)
# 预测
y_pred = knn.predict(X_test)
# 评估模型
accuracy = accuracy_score(y_test, y_pred)
print("模型准确率:", accuracy)
📉 KNN 的优缺点
✅ 优点 | ❌ 缺点 |
---|---|
简单易懂,直观,零训练阶段 | 测试阶段需要计算每个样本与所有训练样本的距离,计算量大 |
可以处理多分类问题 | 存储成本高,需要存储所有训练数据 |
不需要假设数据分布或做特征变换 | 对异常值敏感 |
适用于小型数据集和低维特征数据 | 大数据时性能较差,计算慢 |
🔧 KNN 调参
- K 值的选择:K 值过小可能会导致过拟合,K 值过大可能会导致欠拟合。
- 一般通过交叉验证选择最优的 K 值。
- 距离度量方法:除了欧式距离,还可以使用曼哈顿距离、闵可夫斯基距离等。
- 加权投票:在分类时,近邻的投票可以加权(距离越近的权重越大)。
🛠 应用场景
- 图像分类:将图像映射到特征空间,进行分类。
- 推荐系统:通过相似用户的历史行为进行推荐。
- 模式识别:例如,手写数字识别。
- 医学诊断:通过病人的症状相似度判断可能的疾病。
🚀 KNN 的进阶优化
对于大规模数据,KNN 的计算量比较大,因为每次预测都要计算所有训练样本的距离。一些优化方式包括:
- KD 树:高效地查找最近邻。
- 球树(Ball Tree):适用于高维数据。
- 近似最近邻:通过聚类或降维方法来加速最近邻查找。
🤝 想练手吗?
如果你有兴趣,我们可以试试用 KNN 来做个小项目:
- 手写数字识别(用 KNN 进行分类)
- 预测房价(用 KNN 进行回归)
或者直接带你一起实现一个实际的案例。你想试哪个呢?😊
发表回复