python中如何实现径向基核函数
作者:柳叶吴钩 发布时间:2023-11-28 02:48:45
标签:python,径向基,核函数
1、生成数据集(双月数据集)
class moon_data_class(object):
def __init__(self,N,d,r,w):
self.N=N
self.w=w
self.d=d
self.r=r
def sgn(self,x):
if(x>0):
return 1;
else:
return -1;
def sig(self,x):
return 1.0/(1+np.exp(x))
def dbmoon(self):
N1 = 10*self.N
N = self.N
r = self.r
w2 = self.w/2
d = self.d
done = True
data = np.empty(0)
while done:
#generate Rectangular data
tmp_x = 2*(r+w2)*(np.random.random([N1, 1])-0.5)
tmp_y = (r+w2)*np.random.random([N1, 1])
tmp = np.concatenate((tmp_x, tmp_y), axis=1)
tmp_ds = np.sqrt(tmp_x*tmp_x + tmp_y*tmp_y)
#generate double moon data ---upper
idx = np.logical_and(tmp_ds > (r-w2), tmp_ds < (r+w2))
idx = (idx.nonzero())[0]
if data.shape[0] == 0:
data = tmp.take(idx, axis=0)
else:
data = np.concatenate((data, tmp.take(idx, axis=0)), axis=0)
if data.shape[0] >= N:
done = False
#print (data)
db_moon = data[0:N, :]
#print (db_moon)
#generate double moon data ----down
data_t = np.empty([N, 2])
data_t[:, 0] = data[0:N, 0] + r
data_t[:, 1] = -data[0:N, 1] - d
db_moon = np.concatenate((db_moon, data_t), axis=0)
return db_moon
2、k均值聚类
def k_means(input_cells, k_count):
count = len(input_cells) #点的个数
x = input_cells[0:count, 0]
y = input_cells[0:count, 1]
#随机选择K个点
k = rd.sample(range(count), k_count)
k_point = [[x[i], [y[i]]] for i in k] #保证有序
k_point.sort()
global frames
#global step
while True:
km = [[] for i in range(k_count)] #存储每个簇的索引
#遍历所有点
for i in range(count):
cp = [x[i], y[i]] #当前点
#计算cp点到所有质心的距离
_sse = [distance(k_point[j], cp) for j in range(k_count)]
#cp点到那个质心最近
min_index = _sse.index(min(_sse))
#把cp点并入第i簇
km[min_index].append(i)
#更换质心
k_new = []
for i in range(k_count):
_x = sum([x[j] for j in km[i]]) / len(km[i])
_y = sum([y[j] for j in km[i]]) / len(km[i])
k_new.append([_x, _y])
k_new.sort() #排序
if (k_new != k_point):#一直循环直到聚类中心没有变化
k_point = k_new
else:
return k_point,km
3、高斯核函数
高斯核函数,主要的作用是衡量两个对象的相似度,当两个对象越接近,即a与b的距离趋近于0,则高斯核函数的值趋近于1,反之则趋近于0,换言之:
两个对象越相似,高斯核函数值就越大
作用:
用于分类时,衡量各个类别的相似度,其中sigma参数用于调整过拟合的情况,sigma参数较小时,即要求分类器,加差距很小的类别也分类出来,因此会出现过拟合的问题;
用于模糊控制时,用于模糊集的隶属度。
def gaussian (a,b, sigma):
return np.exp(-norm(a-b)**2 / (2 * sigma**2))
4、求高斯核函数的方差
Sigma_Array = []
for j in range(k_count):
Sigma = []
for i in range(len(center_array[j][0])):
temp = Phi(np.array([center_array[j][0][i],center_array[j][1][i]]),np.array(center[j]))
Sigma.append(temp)
Sigma = np.array(Sigma)
Sigma_Array.append(np.cov(Sigma))
5、显示高斯核函数计算结果
gaussian_kernel_array = []
fig = plt.figure()
ax = Axes3D(fig)
for j in range(k_count):
gaussian_kernel = []
for i in range(len(center_array[j][0])):
temp = Phi(np.array([center_array[j][0][i],center_array[j][1][i]]),np.array(center[j]))
temp1 = gaussian(temp,Sigma_Array[0])
gaussian_kernel.append(temp1)
gaussian_kernel_array.append(gaussian_kernel)
ax.scatter(center_array[j][0], center_array[j][1], gaussian_kernel_array[j],s=20)
plt.show()
6、运行结果
7、完整代码
# coding:utf-8
import numpy as np
import pylab as pl
import random as rd
import imageio
import math
import random
import matplotlib.pyplot as plt
import numpy as np
import mpl_toolkits.mplot3d
from mpl_toolkits.mplot3d import Axes3D
from scipy import *
from scipy.linalg import norm, pinv
from matplotlib import pyplot as plt
random.seed(0)
#定义sigmoid函数和它的导数
def sigmoid(x):
return 1.0/(1.0+np.exp(-x))
def sigmoid_derivate(x):
return x*(1-x) #sigmoid函数的导数
class moon_data_class(object):
def __init__(self,N,d,r,w):
self.N=N
self.w=w
self.d=d
self.r=r
def sgn(self,x):
if(x>0):
return 1;
else:
return -1;
def sig(self,x):
return 1.0/(1+np.exp(x))
def dbmoon(self):
N1 = 10*self.N
N = self.N
r = self.r
w2 = self.w/2
d = self.d
done = True
data = np.empty(0)
while done:
#generate Rectangular data
tmp_x = 2*(r+w2)*(np.random.random([N1, 1])-0.5)
tmp_y = (r+w2)*np.random.random([N1, 1])
tmp = np.concatenate((tmp_x, tmp_y), axis=1)
tmp_ds = np.sqrt(tmp_x*tmp_x + tmp_y*tmp_y)
#generate double moon data ---upper
idx = np.logical_and(tmp_ds > (r-w2), tmp_ds < (r+w2))
idx = (idx.nonzero())[0]
if data.shape[0] == 0:
data = tmp.take(idx, axis=0)
else:
data = np.concatenate((data, tmp.take(idx, axis=0)), axis=0)
if data.shape[0] >= N:
done = False
#print (data)
db_moon = data[0:N, :]
#print (db_moon)
#generate double moon data ----down
data_t = np.empty([N, 2])
data_t[:, 0] = data[0:N, 0] + r
data_t[:, 1] = -data[0:N, 1] - d
db_moon = np.concatenate((db_moon, data_t), axis=0)
return db_moon
def distance(a, b):
return (a[0]- b[0]) ** 2 + (a[1] - b[1]) ** 2
#K均值算法
def k_means(input_cells, k_count):
count = len(input_cells) #点的个数
x = input_cells[0:count, 0]
y = input_cells[0:count, 1]
#随机选择K个点
k = rd.sample(range(count), k_count)
k_point = [[x[i], [y[i]]] for i in k] #保证有序
k_point.sort()
global frames
#global step
while True:
km = [[] for i in range(k_count)] #存储每个簇的索引
#遍历所有点
for i in range(count):
cp = [x[i], y[i]] #当前点
#计算cp点到所有质心的距离
_sse = [distance(k_point[j], cp) for j in range(k_count)]
#cp点到那个质心最近
min_index = _sse.index(min(_sse))
#把cp点并入第i簇
km[min_index].append(i)
#更换质心
k_new = []
for i in range(k_count):
_x = sum([x[j] for j in km[i]]) / len(km[i])
_y = sum([y[j] for j in km[i]]) / len(km[i])
k_new.append([_x, _y])
k_new.sort() #排序
if (k_new != k_point):#一直循环直到聚类中心没有变化
k_point = k_new
else:
pl.figure()
pl.title("N=%d,k=%d iteration"%(count,k_count))
for j in range(k_count):
pl.plot([x[i] for i in km[j]], [y[i] for i in km[j]], color[j%4])
pl.plot(k_point[j][0], k_point[j][1], dcolor[j%4])
return k_point,km
def Phi(a,b):
return norm(a-b)
def gaussian (x, sigma):
return np.exp(-x**2 / (2 * sigma**2))
if __name__ == '__main__':
#计算平面两点的欧氏距离
step=0
color=['.r','.g','.b','.y']#颜色种类
dcolor=['*r','*g','*b','*y']#颜色种类
frames = []
N = 200
d = -4
r = 10
width = 6
data_source = moon_data_class(N, d, r, width)
data = data_source.dbmoon()
# x0 = [1 for x in range(1,401)]
input_cells = np.array([np.reshape(data[0:2*N, 0], len(data)), np.reshape(data[0:2*N, 1], len(data))]).transpose()
labels_pre = [[1] for y in range(1, 201)]
labels_pos = [[0] for y in range(1, 201)]
labels=labels_pre+labels_pos
k_count = 2
center,km = k_means(input_cells, k_count)
test = Phi(input_cells[1],np.array(center[0]))
print(test)
test = distance(input_cells[1],np.array(center[0]))
print(np.sqrt(test))
count = len(input_cells)
x = input_cells[0:count, 0]
y = input_cells[0:count, 1]
center_array = []
for j in range(k_count):
center_array.append([[x[i] for i in km[j]], [y[i] for i in km[j]]])
Sigma_Array = []
for j in range(k_count):
Sigma = []
for i in range(len(center_array[j][0])):
temp = Phi(np.array([center_array[j][0][i],center_array[j][1][i]]),np.array(center[j]))
Sigma.append(temp)
Sigma = np.array(Sigma)
Sigma_Array.append(np.cov(Sigma))
gaussian_kernel_array = []
fig = plt.figure()
ax = Axes3D(fig)
for j in range(k_count):
gaussian_kernel = []
for i in range(len(center_array[j][0])):
temp = Phi(np.array([center_array[j][0][i],center_array[j][1][i]]),np.array(center[j]))
temp1 = gaussian(temp,Sigma_Array[0])
gaussian_kernel.append(temp1)
gaussian_kernel_array.append(gaussian_kernel)
ax.scatter(center_array[j][0], center_array[j][1], gaussian_kernel_array[j],s=20)
plt.show()
来源:https://blog.csdn.net/moge19/article/details/83217745


猜你喜欢
- 这篇文章主要介绍了Python socket聊天脚本代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需
- 1、pyqtgraph库数据可视化效果还不错,特别是窗体程序中图像交互性较好;安装也很方便,用 pip 安装。2、在Python中新建一个
- mysql> select 'name',id from table_b; //'name' 不在ta
- 目录前言什么是socket?如何在 Python 中创建 socket 对象?Python 的套接字库中有多少种可用的套接字方法?服务器套接
- 本文实例讲述了python实现清屏的方法。分享给大家供大家参考。具体分析如下:一试:>>> import os>&g
- 画星星程序2-7-7主要使用turtle.forward前进操作和turtle.left左转操作在屏幕上画星星。#!/usr/bin/env
- 1、说明装饰本质上是一个Python函数,它能使其他函数在没有任何代码变化的情况下增加额外的功能。有了装饰,我们可以抽出大量与函数功能无关的
- 原始数据和目标数据实现SQL语句(最大)selectshop,month,greatest(dz,fz,sp) as maxfromtabl
- 环境: Python 3.6.4 + Pycharm Professional 2017.3.3 + PyQt5 + PyQt5-tools
- 大家都知道系统存储过程是无法用工具导出的(大家可以试试 >任务>生成SQL脚本) 因为系统存储过程一般是不让开发人员修改的。 需
- 即使打开了strict和warnings选项也无妨,下面代码并无错误和警告。#!/usr/bin/perluse strict;use wa
- 1. 创建shell脚本 vim backupdb.sh 创建脚本内容如下: #!/bin/sh db_user="root&qu
- 页面重构需要考虑的一个重点是XHTML代码语义化,就算是在无任何CSS样式修饰的情况下也能给他人在阅读时带来便利,甚至可以夸张点说在搜索引擎
- 描述Python strip() 方法用于移除字符串头尾指定的字符(默认为空格)。语法strip()方法语法:str.strip([char
- 网页编程中,在与数据库打交道的时候我们经常会碰到乱码的经常。本文就将介绍一种ASP读取MySQL数据库出现乱码的解决办法。情景再现:使用My
- Python自身作为一门编程语言,它有多种实现。这里的实现指的是符合Python语言规范的Python解释程序以及标准库等。这些实现虽然实现
- 文字的多行处理在dom元素中很好办。但是canvas中没有提供方法,只有通过截取指定字符串来达到目的。那么下面就介绍我自己处理的办法:wxm
- vue 页面卡死,点击无反应我在结合element做表单的时候,进入编辑页时,点击切换不生效,但是value值已改变,就是view视图层无反
- 在 JavaScript 中,可以用 instanceof 来判断一个对象是不是某个类或其子类的实例。比如:// 代码
- 本文实例讲述了sql server实现分页的方法。分享给大家供大家参考,具体如下:declare @index int,@num intse