《机器学习算法与实现 —— Python编程与应用实例》 (Machine Learning Algorithms and Implementation: Python Programming and Application Examples): the KNN Algorithm
<div class='showpostmsg'><p>The k-Nearest Neighbor (kNN) classification algorithm is a theoretically mature method and one of the simplest machine learning algorithms. Its idea is: <strong>if the majority of the k most similar samples (i.e., the nearest neighbors in feature space) of a sample belong to a certain class, then the sample also belongs to that class.</strong> Put simply, kNN works like this: <strong>you have a set of data whose classes are already known; when a new data point arrives, you compute its distance to every point in the training data, pick the K training points nearest to it, look at which classes those points belong to, and assign the new point a class by majority vote.</strong> The kNN algorithm can be used not only for classification but also for regression.</p><p> </p>
<p>Let's work through the experiment from the book.</p>
<p>First, generate the training and test data:</p>
<pre>
<code class="language-python">%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

# Generate simulated data: two Gaussian clusters
# (the cluster centers [3, 4] and [8, 8] are assumed values;
#  the originals were lost from the post)
np.random.seed(314)
data_size1 = 1000
x1 = np.random.randn(data_size1, 2)*2 + np.array([3, 4])
y1 = [0 for _ in range(data_size1)]
data_size2 = 1000
x2 = np.random.randn(data_size2, 2)*2 + np.array([8, 8])
y2 = [1 for _ in range(data_size2)]

# Merge into the full dataset and shuffle
x = np.concatenate((x1, x2), axis=0)
y = np.concatenate((y1, y2), axis=0)
data_size_all = data_size1 + data_size2
shuffled_index = np.random.permutation(data_size_all)
x = x[shuffled_index]
y = y[shuffled_index]

# Split into training and test sets (70% / 30%)
split_index = int(data_size_all*0.7)
x_train = x[:split_index]
y_train = y[:split_index]
x_test = x[split_index:]
y_test = y[split_index:]

# Plot the training data
for i in range(split_index):
    if y_train[i] == 0:
        plt.scatter(x_train[i, 0], x_train[i, 1], s=38, c='r', marker='.')
    else:
        plt.scatter(x_train[i, 0], x_train[i, 1], s=38, c='b', marker='^')
#plt.rcParams['figure.figsize'] = (12.0, 8.0)
plt.title("Training Data")
plt.savefig("fig-res-knn-traindata.pdf")
plt.show()

# Plot the test data
for i in range(data_size_all - split_index):
    if y_test[i] == 0:
        plt.scatter(x_test[i, 0], x_test[i, 1], s=38, c='r', marker='.')
    else:
        plt.scatter(x_test[i, 0], x_test[i, 1], s=38, c='b', marker='^')
#plt.rcParams['figure.figsize'] = (12.0, 8.0)
plt.title("Test Data")
plt.savefig("fig-res-knn-testdata.pdf")
plt.show()</code></pre>
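<p>The manual shuffle-and-split bookkeeping above can also be done in a single call with scikit-learn's train_test_split. A minimal sketch, assuming scikit-learn is installed (the cluster centers and the variable names x_tr/x_te/y_tr/y_te are illustrative, not from the book):</p>

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(314)
# Two Gaussian clusters, mirroring the generation step above
# (centers [3, 4] and [8, 8] are illustrative values)
x1 = rng.standard_normal((1000, 2)) * 2 + np.array([3, 4])
x2 = rng.standard_normal((1000, 2)) * 2 + np.array([8, 8])
x_all = np.concatenate((x1, x2), axis=0)
y_all = np.concatenate((np.zeros(1000), np.ones(1000)))

# Shuffle and split 70% / 30% in one call
x_tr, x_te, y_tr, y_te = train_test_split(
    x_all, y_all, test_size=0.3, random_state=314)
print(x_tr.shape, x_te.shape)  # (1400, 2) (600, 2)
```

<p>train_test_split shuffles by default, so it replaces both the np.random.permutation step and the slicing.</p>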
<p> </p>
<p> </p>
<p>The KNN implementation breaks down into three steps: compute distances, take the k nearest samples, and vote to decide the class. These are implemented by the functions knn_distance(), knn_vote(), and knn_predict():</p>
<pre>
<code class="language-python">import numpy as np

def knn_distance(v1, v2):
    """Squared Euclidean distance between two vectors."""
    return np.sum(np.square(v1 - v2))

def knn_vote(ys):
    """Return the most frequent class among the labels ys (majority vote)."""
    vote_dict = {}
    for y in ys:
        if y not in vote_dict:
            vote_dict[y] = 1
        else:
            vote_dict[y] += 1

    method = 1
    # Method 1 - loop over the classes to find the most frequent one
    if method == 1:
        maxv = maxk = 0
        for y in np.unique(ys):
            if maxv < vote_dict[y]:
                maxv = vote_dict[y]
                maxk = y
        return maxk
    # Method 2 - sort the vote counts in descending order
    if method == 2:
        sorted_vote_dict = sorted(vote_dict.items(),
                                  key=lambda item: item[1],
                                  reverse=True)
        return sorted_vote_dict[0][0]

def knn_predict(x, train_x, train_y, k=3):
    """
    Classify a given sample.

    Parameters
        x       - the sample to classify
        train_x - training samples
        train_y - training labels
        k       - number of nearest neighbors
    """
    dist_arr = [knn_distance(x, train_x[j]) for j in range(len(train_x))]
    sorted_index = np.argsort(dist_arr)
    top_k_index = sorted_index[:k]
    ys = train_y[top_k_index]
    return knn_vote(ys)

# Classify every training sample
y_train_est = [knn_predict(x_train[i], x_train, y_train, k=5)
               for i in range(len(x_train))]
print(y_train_est)

# Plot the results
for i in range(len(y_train_est)):
    if y_train_est[i] == 0:
        plt.scatter(x_train[i, 0], x_train[i, 1], s=38, c='r', marker='.')
    else:
        plt.scatter(x_train[i, 0], x_train[i, 1], s=38, c='b', marker='^')
#plt.rcParams['figure.figsize'] = (12.0, 8.0)
plt.title("Train Results")
plt.savefig("fig-res-knn-train-res.pdf")
plt.show()</code></pre>
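<p>The per-sample loop inside knn_predict can be replaced by NumPy broadcasting, classifying a whole batch of queries at once. A sketch using the same squared-Euclidean distance (the name knn_predict_batch and the mean-threshold voting shortcut are my additions, not from the book; the shortcut assumes binary 0/1 labels):</p>

```python
import numpy as np

def knn_predict_batch(xs, train_x, train_y, k=3):
    """Classify all rows of xs at once via broadcasting.

    Distance shapes: (n_query, 1, d) - (1, n_train, d) -> (n_query, n_train)
    """
    diff = xs[:, None, :] - train_x[None, :, :]
    dist = np.sum(diff ** 2, axis=2)           # squared Euclidean distances
    top_k = np.argsort(dist, axis=1)[:, :k]    # indices of k nearest per query
    votes = train_y[top_k]                     # (n_query, k) neighbor labels
    # Majority vote per row; valid only for binary 0/1 labels
    return (votes.mean(axis=1) > 0.5).astype(int)

# Tiny sanity check: two well-separated clusters
train_x = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
train_y = np.array([0, 0, 1, 1])
print(knn_predict_batch(np.array([[0.05, 0.0], [5.05, 5.0]]),
                        train_x, train_y, k=3))  # prints [0 1]
```

<p>For multi-class labels, the mean-threshold trick no longer works; the knn_vote() function above generalizes by counting votes per class. The broadcasting version trades memory (an n_query × n_train distance matrix) for speed.</p>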
<p> </p>
<p>Now compute the training and test accuracy:</p>
<pre>
<code class="language-python"># Accuracy on the training data
n_correct = 0
for i in range(len(x_train)):
    if y_train_est[i] == y_train[i]:
        n_correct += 1
accuracy = n_correct / len(x_train) * 100.0
print("Train Accuracy: %f%%" % accuracy)

# Accuracy on the test data
y_test_est = [knn_predict(x_test[i], x_train, y_train, 3)
              for i in range(len(x_test))]
n_correct = 0
for i in range(len(x_test)):
    if y_test_est[i] == y_test[i]:
        n_correct += 1
accuracy = n_correct / len(x_test) * 100.0
print("Test Accuracy: %f%%" % accuracy)
print(n_correct, len(x_test))</code></pre>
<p>Train Accuracy: 97.857143%</p>
<pre>
Test Accuracy: 96.666667%</pre>
<pre>
58 60</pre>
<p>This verifies the feasibility and reliability of the KNN algorithm.</p>
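<p>As a final cross-check, the same algorithm is available off the shelf as scikit-learn's KNeighborsClassifier. A minimal sketch, assuming scikit-learn is installed (x_demo/y_demo form a small illustrative dataset, not the one generated above):</p>

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(314)
# Two well-separated Gaussian clusters (centers are illustrative)
x_demo = np.concatenate((rng.standard_normal((100, 2)),
                         rng.standard_normal((100, 2)) + np.array([6, 6])))
y_demo = np.concatenate((np.zeros(100), np.ones(100)))

clf = KNeighborsClassifier(n_neighbors=5)
clf.fit(x_demo, y_demo)
print(clf.score(x_demo, y_demo))              # training accuracy
print(clf.predict([[0.0, 0.0], [6.0, 6.0]]))  # prints [0. 1.]
```

<p>KNeighborsClassifier defaults to Euclidean distance with uniform-weight voting, which matches the scheme implemented by hand above.</p>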
<p> </p>
</div>