# 數據挖掘十大算法--K近鄰算法
> 來源:http://blog.csdn.net/u011067360/article/details/23941577
_k_-近鄰算法是基于實例的學習方法中最基本的,先介紹基于實例學習的相關概念。
**一、基于實例的學習。**
1、已知一系列的訓練樣例,很多學習方法為目標函數建立起明確的一般化描述;但與此不同,基于實例的學習方法只是簡單地把訓練樣例存儲起來。
從這些實例中泛化的工作被推遲到必須分類新的實例時。每當學習器遇到一個新的查詢實例,它分析這個新實例與以前存儲的實例的關系,并據此把一個目標函數值賦給新實例。
2、基于實例的方法可以為不同的待分類查詢實例建立不同的目標函數逼近。事實上,很多技術只建立目標函數的局部逼近,將其應用于與新查詢實例鄰近的實例,而從不建立在整個實例空間上都表現良好的逼近。當目標函數很復雜,但它可用不太復雜的局部逼近描述時,這樣做有顯著的優勢。
3、基于實例方法的不足:
(1)分類新實例的開銷可能很大。這是因為幾乎所有的計算都發生在分類時,而不是在第一次遇到訓練樣例時。所以,如何有效地索引訓練樣例,以減少查詢時所需計算是一個重要的實踐問題。
(2)當從存儲器中檢索相似的訓練樣例時,它們一般考慮實例的所有屬性。如果目標概念僅依賴于很多屬性中的幾個時,那么真正最“相似”的實例之間很可能相距甚遠。
## 二、k-近鄰法
基于實例的學習方法中最基本的是_k_-近鄰算法。這個算法假定所有的實例對應于_n_維歐氏空間?_<sup>n</sup>_中的點。一個實例的最近鄰是根據標準歐氏距離定義的。更精確地講,把任意的實例_x_表示為下面的特征向量:
_a_<sub>1</sub>(_x_),_a_<sub>2</sub>(_x_),...,_a<sub>n</sub>_(_x_)
其中_a<sub>r</sub>_(_x_)表示實例_x_的第_r_個屬性值。那么兩個實例_x<sub>i</sub>_和_x<sub>j</sub>_間的距離定義為_d_(_x<sub>i</sub>_,_x<sub>j</sub>_),其中:

說明:
1、在最近鄰學習中,目標函數值可以為離散值也可以為實值。
2、我們先考慮學習以下形式的離散目標函數。其中_V_是有限集合{_v_<sub>1,</sub>..._v<sub>s</sub>_}。下表給出了逼近離散目標函數的_k-_近鄰算法。
3、正如下表中所指出的,這個算法的返回值_f'_(_x<sub>q</sub>_)為對_f_(_x<sub>q</sub>_)的估計,它就是距離_x<sub>q</sub>_最近的_k_個訓練樣例中最普遍的_f_值。
4、如果我們選擇_k_=1,那么“1-近鄰算法”就把_f_(_x<sub>i</sub>_)賦給(_x<sub>q</sub>_),其中_x<sub>i</sub>_是最靠近_x<sub>q</sub>_的訓練實例。對于較大的_k_值,這個算法返回前_k_個最靠近的訓練實例中最普遍的_f_值。
**逼近離散值函數_f_: ?_<sup>n_</sup>V_的_k_-近鄰算法**
> 訓練算法:
> 對于每個訓練樣例<_x_,_f_(_x_)>,把這個樣例加入列表_training___examples_
> 分類算法:
> 給定一個要分類的查詢實例_x<sub>q</sub>_
> 在_training___examples_中選出最靠近_x<sub>q</sub>_的_k_個實例,并用_x_<sub>1....</sub>_x<sub>k</sub>_表示
> 返回
> 
> 其中如果_a_=_b_那么_d_(_a_,_b_)=1,否則_d_(_a_,_b_)=0。
下圖圖解了一種簡單情況下的_k_-近鄰算法,在這里實例是二維空間中的點,目標函數具有布爾值。正反訓練樣例用“+”和“-”分別表示。圖中也畫出了一個查詢點_x<sub>q</sub>_。注意在這幅圖中,1-近鄰算法把_x<sub>q</sub>_分類為正例,然而5-近鄰算法把_x<sub>q</sub>_分類為反例。

圖解說明:左圖畫出了一系列的正反訓練樣例和一個要分類的查詢實例_x<sub>q</sub>_。1-近鄰算法把_x<sub>q</sub>_分類為正例,然而5-近鄰算法把_x<sub>q</sub>_分類為反例。
右圖是對于一個典型的訓練樣例集合1-近鄰算法導致的決策面。圍繞每個訓練樣例的凸多邊形表示最靠近這個點的實例空間(即這個空間中的實例會被1-近鄰算法賦予該訓練樣例所屬的分類)。
對前面的_k_-近鄰算法作簡單的修改后,它就可被用于逼近連續值的目標函數。為了實現這一點,我們讓算法計算_k_個最接近樣例的平均值,而不是計算其中的最普遍的值。更精確地講,為了逼近一個實值目標函數,我們只要把算法中的公式替換為:

### 三、距離加權最近鄰算法
對_k_-近鄰算法的一個顯而易見的改進是對_k_個近鄰的貢獻加權,根據它們相對查詢點_x<sub>q</sub>_的距離,將較大的權值賦給較近的近鄰。
例如,在上表逼近離散目標函數的算法中,我們可以根據每個近鄰與_x<sub>q</sub>_的距離平方的倒數加權這個近鄰的“選舉權”。
方法是通過用下式取代上表算法中的公式來實現:

其中

為了處理查詢點_x<sub>q</sub>_恰好匹配某個訓練樣例_x<sub>i</sub>_,從而導致分母為0的情況,我們令這種情況下的_f '(x<sub>q</sub>_)等于_f_(_x<sub>i</sub>_)。如果有多個這樣的訓練樣例,我們使用它們中占多數的分類。
我們也可以用類似的方式對實值目標函數進行距離加權,只要用下式替換上表的公式:

其中_w<sub>i</sub>_的定義與之前公式中相同。
注意這個公式中的分母是一個常量,它將不同權值的貢獻歸一化(例如,它保證如果對所有的訓練樣例_x<sub>i</sub>_,_f_(_x<sub>i</sub>_)=_c_,那么(_x<sub>q</sub>_)<--_c_)。
注意以上k-近鄰算法的所有變體都只考慮k個近鄰以分類查詢點。如果使用按距離加權,那么允許所有的訓練樣例影響_x<sub>q</sub>_的分類事實上沒有壞處,因為非常遠的實例對(_x<sub>q</sub>_)的影響很小。考慮所有樣例的惟一不足是會使分類運行得更慢。如果分類一個新的查詢實例時考慮所有的訓練樣例,我們稱此為全局(global)法。如果僅考慮最靠近的訓練樣例,我們稱此為局部(local)法。
**四、對_k_-近鄰算法的說明**
按距離加權的_k_-近鄰算法是一種非常有效的歸納推理方法。它對訓練數據中的噪聲有很好的魯棒性,而且當給定足夠大的訓練集合時它也非常有效。注意通過取_k_個近鄰的加權平均,可以消除孤立的噪聲樣例的影響。
1、問題一:近鄰間的距離會被大量的不相關屬性所支配。
應用_k_-近鄰算法的一個實踐問題是,實例間的距離是根據實例的所有屬性(也就是包含實例的歐氏空間的所有坐標軸)計算的。這與那些只選擇全部實例屬性的一個子集的方法不同,例如決策樹學習系統。
比如這樣一個問題:每個實例由20個屬性描述,但在這些屬性中僅有2個與它的分類是有關。在這種情況下,這兩個相關屬性的值一致的實例可能在這個20維的實例空間中相距很遠。結果,依賴這20個屬性的相似性度量會誤導_k_-近鄰算法的分類。近鄰間的距離會被大量的不相關屬性所支配。這種由于存在很多不相關屬性所導致的難題,有時被稱為維度災難(curse of dimensionality)。最近鄰方法對這個問題特別敏感。
2、解決方法:當計算兩個實例間的距離時對每個屬性加權。
這相當于按比例縮放歐氏空間中的坐標軸,縮短對應于不太相關屬性的坐標軸,拉長對應于更相關的屬性的坐標軸。每個坐標軸應伸展的數量可以通過交叉驗證的方法自動決定。
?
3、問題二:應用_k_-近鄰算法的另外一個實踐問題是如何建立高效的索引。因為這個算法推遲所有的處理,直到接收到一個新的查詢,所以處理每個新查詢可能需要大量的計算。
4、解決方法:目前已經開發了很多方法用來對存儲的訓練樣例進行索引,以便在增加一定存儲開銷情況下更高效地確定最近鄰。一種索引方法是_kd_-tree(Bentley 1975;Friedman et al. 1977),它把實例存儲在樹的葉結點內,鄰近的實例存儲在同一個或附近的結點內。通過測試新查詢_x<sub>q</sub>_的選定屬性,樹的內部結點把查詢_x<sub>q</sub>_排列到相關的葉結點。
# 機器學習與數據挖掘-K最近鄰(KNN)算法的實現(java和python版)
> 來源:http://blog.csdn.net/u011067360/article/details/45937327
KNN算法基礎思想前面文章可以參考,這里主要講解java和python的兩種簡單實現,也主要是理解簡單的思想。
**python版本:**
這里實現一個手寫識別算法,這里只簡單識別0~9熟悉,在上篇文章中也展示了手寫識別的應用,可以參考:[機器學習與數據挖掘-logistic回歸及手寫識別實例的實現](http://blog.csdn.net/u011067360/article/details/45624517)
輸入:每個手寫數字已經事先處理成32*32的二進制文本,存儲為txt文件。0~9每個數字都有10個訓練樣本,5個測試樣本。訓練樣本集如下圖:左邊是文件目錄,右邊是其中一個文件打開顯示的結果,看著像1,這里有0~9,每個數字都有是個樣本來作為訓練集。

第一步:將每個txt文本轉化為一個向量,即32*32的數組轉化為1*1024的數組,這個1*1024的數組用機器學習的術語來說就是特征向量。
```
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
```
第二步:訓練樣本中有10*10個圖片,可以合并成一個100*1024的矩陣,每一行對應一個圖片,也就是一個txt文檔。
```
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits')
print trainingFileList
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
#print hwLabels
#print fileNameStr
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
#print trainingMat[i,:]
#print len(trainingMat[i,:])
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
if (classifierResult != classNumStr): errorCount += 1.0
print "\nthe total number of errors is: %d" % errorCount
print "\nthe total error rate is: %f" % (errorCount/float(mTest))
```
第三步:測試樣本中有10*5個圖片,同樣的,對于測試圖片,將其轉化為1*1024的向量,然后計算它與訓練樣本中各個圖片的“距離”(這里兩個向量的距離采用歐式距離),然后對距離排序,選出較小的前k個,因為這k個樣本來自訓練集,是已知其代表的數字的,所以被測試圖片所代表的數字就可以確定為這k個中出現次數最多的那個數字。
```
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
#tile(A,(m,n))
print dataSet
print "----------------"
print tile(inX, (dataSetSize,1))
print "----------------"
diffMat = tile(inX, (dataSetSize,1)) - dataSet
print diffMat
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
```
**全部實現代碼:**
```
#-*-coding:utf-8-*-
from numpy import *
import operator
from os import listdir
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
#tile(A,(m,n))
print dataSet
print "----------------"
print tile(inX, (dataSetSize,1))
print "----------------"
diffMat = tile(inX, (dataSetSize,1)) - dataSet
print diffMat
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits')
print trainingFileList
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
#print hwLabels
#print fileNameStr
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
#print trainingMat[i,:]
#print len(trainingMat[i,:])
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
if (classifierResult != classNumStr): errorCount += 1.0
print "\nthe total number of errors is: %d" % errorCount
print "\nthe total error rate is: %f" % (errorCount/float(mTest))
handwritingClassTest()
```
運行結果:源碼文章尾可下載

**java版本**
先看看訓練集和測試集:
訓練集:

測試集:

訓練集最后一列代表分類(0或者1)
代碼實現:
KNN算法主體類:
```
package Marchinglearning.knn2;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
/**
* KNN算法主體類
*/
public class KNN {
/**
* 設置優先級隊列的比較函數,距離越大,優先級越高
*/
private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
public int compare(KNNNode o1, KNNNode o2) {
if (o1.getDistance() >= o2.getDistance()) {
return 1;
} else {
return 0;
}
}
};
/**
* 獲取K個不同的隨機數
* @param k 隨機數的個數
* @param max 隨機數最大的范圍
* @return 生成的隨機數數組
*/
public List<Integer> getRandKNum(int k, int max) {
List<Integer> rand = new ArrayList<Integer>(k);
for (int i = 0; i < k; i++) {
int temp = (int) (Math.random() * max);
if (!rand.contains(temp)) {
rand.add(temp);
} else {
i--;
}
}
return rand;
}
/**
* 計算測試元組與訓練元組之前的距離
* @param d1 測試元組
* @param d2 訓練元組
* @return 距離值
*/
public double calDistance(List<Double> d1, List<Double> d2) {
System.out.println("d1:"+d1+",d2"+d2);
double distance = 0.00;
for (int i = 0; i < d1.size(); i++) {
distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
}
return distance;
}
/**
* 執行KNN算法,獲取測試元組的類別
* @param datas 訓練數據集
* @param testData 測試元組
* @param k 設定的K值
* @return 測試元組的類別
*/
public String knn(List<List<Double>> datas, List<Double> testData, int k) {
PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);
List<Integer> randNum = getRandKNum(k, datas.size());
System.out.println("randNum:"+randNum.toString());
for (int i = 0; i < k; i++) {
int index = randNum.get(i);
List<Double> currData = datas.get(index);
String c = currData.get(currData.size() - 1).toString();
System.out.println("currData:"+currData+",c:"+c+",testData"+testData);
//計算測試元組與訓練元組之前的距離
KNNNode node = new KNNNode(index, calDistance(testData, currData), c);
pq.add(node);
}
for (int i = 0; i < datas.size(); i++) {
List<Double> t = datas.get(i);
System.out.println("testData:"+testData);
System.out.println("t:"+t);
double distance = calDistance(testData, t);
System.out.println("distance:"+distance);
KNNNode top = pq.peek();
if (top.getDistance() > distance) {
pq.remove();
pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
}
}
return getMostClass(pq);
}
/**
* 獲取所得到的k個最近鄰元組的多數類
* @param pq 存儲k個最近近鄰元組的優先級隊列
* @return 多數類的名稱
*/
private String getMostClass(PriorityQueue<KNNNode> pq) {
Map<String, Integer> classCount = new HashMap<String, Integer>();
for (int i = 0; i < pq.size(); i++) {
KNNNode node = pq.remove();
String c = node.getC();
if (classCount.containsKey(c)) {
classCount.put(c, classCount.get(c) + 1);
} else {
classCount.put(c, 1);
}
}
int maxIndex = -1;
int maxCount = 0;
Object[] classes = classCount.keySet().toArray();
for (int i = 0; i < classes.length; i++) {
if (classCount.get(classes[i]) > maxCount) {
maxIndex = i;
maxCount = classCount.get(classes[i]);
}
}
return classes[maxIndex].toString();
}
}
```
KNN結點類,用來存儲最近鄰的k個元組相關的信息
```
package Marchinglearning.knn2;
/**
* KNN結點類,用來存儲最近鄰的k個元組相關的信息
*/
public class KNNNode {
private int index; // 元組標號
private double distance; // 與測試元組的距離
private String c; // 所屬類別
public KNNNode(int index, double distance, String c) {
super();
this.index = index;
this.distance = distance;
this.c = c;
}
public int getIndex() {
return index;
}
public void setIndex(int index) {
this.index = index;
}
public double getDistance() {
return distance;
}
public void setDistance(double distance) {
this.distance = distance;
}
public String getC() {
return c;
}
public void setC(String c) {
this.c = c;
}
}
```
KNN算法測試類
```
package Marchinglearning.knn2;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
/**
* KNN算法測試類
*/
public class TestKNN {
/**
* 從數據文件中讀取數據
* @param datas 存儲數據的集合對象
* @param path 數據文件的路徑
*/
public void read(List<List<Double>> datas, String path){
try {
BufferedReader br = new BufferedReader(new FileReader(new File(path)));
String data = br.readLine();
List<Double> l = null;
while (data != null) {
String t[] = data.split(" ");
l = new ArrayList<Double>();
for (int i = 0; i < t.length; i++) {
l.add(Double.parseDouble(t[i]));
}
datas.add(l);
data = br.readLine();
}
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 程序執行入口
* @param args
*/
public static void main(String[] args) {
TestKNN t = new TestKNN();
String datafile = new File("").getAbsolutePath() + File.separator +"knndata2"+File.separator + "datafile.data";
String testfile = new File("").getAbsolutePath() + File.separator +"knndata2"+File.separator +"testfile.data";
System.out.println("datafile:"+datafile);
System.out.println("testfile:"+testfile);
try {
List<List<Double>> datas = new ArrayList<List<Double>>();
List<List<Double>> testDatas = new ArrayList<List<Double>>();
t.read(datas, datafile);
t.read(testDatas, testfile);
KNN knn = new KNN();
for (int i = 0; i < testDatas.size(); i++) {
List<Double> test = testDatas.get(i);
System.out.print("測試元組: ");
for (int j = 0; j < test.size(); j++) {
System.out.print(test.get(j) + " ");
}
System.out.print("類別為: ");
System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
```
運行結果為:

資源下載:
[python版本下載](http://download.csdn.net/detail/u011067360/8731843)
[java版本下載](http://download.csdn.net/detail/u011067360/8731847)