• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

简单的线性分类——MATLAB,python3实现

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

  看李政轩老师讲的Kernel,讲的非常好!前面有几道作业题,用MATLAB简单做了下,不知道对不对,错误之处还请指出。

  题目是这样的。

  一、MATLAB版本:

clear;
clc
% 生成training sample
MU1 = [1 2];
MU2 = [4 6];
SIGMA1 = [4 4; 4 9];
SIGMA2 = [4 2; 2 4];

M1 = mvnrnd(MU1,SIGMA1,100);
M2 = mvnrnd(MU2,SIGMA2,100);

%生成testing sample
TEST1 = mvnrnd(MU1,SIGMA1,50);
TEST2 = mvnrnd(MU2,SIGMA2,50);

%向量化的计算
%中间点C
C = (MU1+MU2)/2;
C_M = repmat(C,50,1);

%MUi vector
TRAIN_V = MU1 - MU2;
TRAIN_V_M = repmat(TRAIN_V,50,1);

%TEST vector
TEST1_V = TEST1 - C_M;
TEST2_V = TEST2 - C_M;

%预测第一个测试集
num1 = 0;
for (i=1:50)
    d = dot(TRAIN_V,TEST1_V(i,:));
    if d >0
        num1 = num1 + 1;
    end
end

disp(['测试集1输入数据数量为:',num2str(length(TEST1_V)),'正确分类的数量为:',num2str(num1)])
disp(['测试集1的预测准确度为:',num2str(num1/length(TEST1_V))])

num2 = 0;
for (i=1:50)
    d = dot(TRAIN_V,TEST2_V(i,:));
    if d <0
        num2 = num2 + 1;
    end
end

disp(['测试集2输入数据数量为:',num2str(length(TEST2_V)),'正确分类的数量为:',num2str(num2)])
disp(['测试集2的预测准确度为:',num2str(num2/length(TEST2_V))])

%两样本中心值连线的斜率
K = TRAIN_V(2)/TRAIN_V(1);
%两样本中心值连线的中垂线的斜率
k = K/(-1);

x = min(TEST1):0.1:max(TEST2);
y = k*(x-C(1))+C(2);

 plot(TEST1,TEST2,'O',MU1,MU2,'o',x,y)

  输出如下:

  作图:

   

  二、python3版本

  注意这里原始的training data 做了改动,原理是一样的。

# -*- coding: utf-8 -*-
"""
Created on Sun Nov  6 20:02:02 2016

@author: Administrator
"""

import numpy as np
from matplotlib import pyplot as plt


# train matrix
def get_train_data():		
	M1 = np.random.random((100,2))
	M2 = np.random.random((100,2)) - 0.7
	plt.plot(M1[:,0],M1[:,1], 'ro')
	plt.plot(M2[:,0],M2[:,1], 'go')
	return M1,M2

def classify(M1,M2,test_data):
	mean1 = np.mean(M1, axis=0)
	mean2 = np.mean(M2, axis=0)
	mean = (mean1 + mean2)/2
	# for plot
	km = (mean1[1]-mean2[1])/(mean1[0]-mean2[0])
	k = km/(-1)
	min_x = np.min(M2)
	max_x = np.max(M1)
	x = np.linspace(min_x, max_x, 100)
	y = k*(x-mean[0])+mean[1]
	plt.plot(x,y,'y')
	
	vector_train = mean1 - mean
	vector_test = test_data - mean
	vector_dot = np.dot(vector_train, vector_test)
	sgn = np.sign(vector_dot)
	
	return sgn
	
def get_test_data():
	M = np.random.random((50,2))
	plt.plot(M[:,0],M[:,1],'*y')
	return M

if __name__=="__main__":
	M1,M2 = get_train_data()
	test_data = get_test_data()
	right_count = 0
	for test_i in test_data:
		classx = classify(M1,M2,test_i)
		if classx == 1:
			right_count += 1
	plt.show()
	print("The accuracy of right classification is %s"%str(right_count/len(test_data)))

  输出:

 

 


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap