基于numpy实现逻辑回归
作者:Giao哥不瘦到100不改名 发布时间:2023-06-21 10:04:25
标签:numpy,逻辑回归
本文实例为大家分享了基于numpy实现逻辑回归的具体代码,供大家参考,具体内容如下
交叉熵损失函数;sigmoid激励函数
基于numpy的逻辑回归的程序如下:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets.samples_generator import make_classification
class logistic_regression():
def __init__(self):
pass
def sigmoid(self, x):
z = 1 /(1 + np.exp(-x))
return z
def initialize_params(self, dims):
W = np.zeros((dims, 1))
b = 0
return W, b
def logistic(self, X, y, W, b):
num_train = X.shape[0]
num_feature = X.shape[1]
a = self.sigmoid(np.dot(X, W) + b)
cost = -1 / num_train * np.sum(y * np.log(a) + (1 - y) * np.log(1 - a))
dW = np.dot(X.T, (a - y)) / num_train
db = np.sum(a - y) / num_train
cost = np.squeeze(cost)#[]列向量,易于plot
return a, cost, dW, db
def logistic_train(self, X, y, learning_rate, epochs):
W, b = self.initialize_params(X.shape[1])
cost_list = []
for i in range(epochs):
a, cost, dW, db = self.logistic(X, y, W, b)
W = W - learning_rate * dW
b = b - learning_rate * db
if i % 100 == 0:
cost_list.append(cost)
if i % 100 == 0:
print('epoch %d cost %f' % (i, cost))
params = {
'W': W,
'b': b
}
grads = {
'dW': dW,
'db': db
}
return cost_list, params, grads
def predict(self, X, params):
y_prediction = self.sigmoid(np.dot(X, params['W']) + params['b'])
#二分类
for i in range(len(y_prediction)):
if y_prediction[i] > 0.5:
y_prediction[i] = 1
else:
y_prediction[i] = 0
return y_prediction
#精确度计算
def accuracy(self, y_test, y_pred):
correct_count = 0
for i in range(len(y_test)):
for j in range(len(y_pred)):
if y_test[i] == y_pred[j] and i == j:
correct_count += 1
accuracy_score = correct_count / len(y_test)
return accuracy_score
#创建数据
def create_data(self):
X, labels = make_classification(n_samples=100, n_features=2, n_redundant=0, n_informative=2)
labels = labels.reshape((-1, 1))
offset = int(X.shape[0] * 0.9)
#训练集与测试集的划分
X_train, y_train = X[:offset], labels[:offset]
X_test, y_test = X[offset:], labels[offset:]
return X_train, y_train, X_test, y_test
#画图函数
def plot_logistic(self, X_train, y_train, params):
n = X_train.shape[0]
xcord1 = []
ycord1 = []
xcord2 = []
ycord2 = []
for i in range(n):
if y_train[i] == 1:#1类
xcord1.append(X_train[i][0])
ycord1.append(X_train[i][1])
else:#0类
xcord2.append(X_train[i][0])
ycord2.append(X_train[i][1])
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(xcord1, ycord1, s=32, c='red')
ax.scatter(xcord2, ycord2, s=32, c='green')#画点
x = np.arange(-1.5, 3, 0.1)
y = (-params['b'] - params['W'][0] * x) / params['W'][1]#画二分类直线
ax.plot(x, y)
plt.xlabel('X1')
plt.ylabel('X2')
plt.show()
if __name__ == "__main__":
model = logistic_regression()
X_train, y_train, X_test, y_test = model.create_data()
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
# (90, 2)(90, 1)(10, 2)(10, 1)
#训练模型
cost_list, params, grads = model.logistic_train(X_train, y_train, 0.01, 1000)
print(params)
#计算精确度
y_train_pred = model.predict(X_train, params)
accuracy_score_train = model.accuracy(y_train, y_train_pred)
print('train accuracy is:', accuracy_score_train)
y_test_pred = model.predict(X_test, params)
accuracy_score_test = model.accuracy(y_test, y_test_pred)
print('test accuracy is:', accuracy_score_test)
model.plot_logistic(X_train, y_train, params)
结果如下所示:
来源:https://blog.csdn.net/exsolar_521/article/details/108206644


猜你喜欢
- 最近的项目中大量涉及数据的预处理工作,对于ndarray的使用非常频繁。其中ndarray如何进行数值筛选,总结了几种方法。1.按某些固定值
- 根据“廖雪峰”的教程进行python学习,计划每天抽出1-2个小时的时间进行充电。Python是著名的“龟叔”Guido van Rossu
- 一、模型参数的保存和加载 torch.save(module.state_dict(), path):使用module.state
- 先看看弹窗效果,如下: //放大图片 $(page).on('click','.popupIm
- 本文利用Python3启动简单的HTTP服务器,以实现在同一网络中共享本地文件。启动HTTP服务器打开终端,转入目标文件所在文件夹,键入以下
- 关于段落<p></p>相信大家已经都在自己的工作中开始关注并应用了。因为那真的是非常简单的事,只要你愿意你随时都可以
- 第1章 新建工程和创建app新建工程和创建app就不用贴出来了,我这里是测试图片上传的功能能否实现,所以项目都是新的,正常在以有的app下就
- function test(){ return 123; } 显然这是一个函数声明,那下面的呢 var b=f
- 写这个的目地,主要是系统理下目前产品设计的流程,提醒自己尽量去避免一些常见的问题,也能让大家系统的了解天极网的产品设计流程。当然,不保证任何
- 1.迭代器当您创建一个列表时,你可以逐个读取它的项。逐项读取其项称为迭代:mylist是一个可迭代的对象。当你使用列表解析式时,你创建了一个
- 在记忆里,关于时间方面常的SQL也就下面这两个了,大多数朋友问题中所涉及到的数据库都ACCESS的,在些,也就写出这两SQL了。年代久远,目
- 纳什均衡是一种博弈论中的概念,它描述了一种平衡状态,其中每个参与者都不能通过独立改变其决策来提高自己的利益。在 Python 中,可以使用一
- 在平时的工作中,我们经常会遇到需要批量创建文件的情况,例如,汇总一个月中每天回复问题的文件等,这里,我们以如何使用当前日期时间创建文件为例:
- 本文实例讲述了纯JavaScript实现的分页插件。分享给大家供大家参考。具体如下://总条数(必填)var Num=Number(<
- 代码如下:--函数 CREATE function fn_GetPy(@str nvarchar(4000)) returns nvarch
- 1. show variables like '%profiling%';(查看profiling信息) &nbs
- 本文实例讲述了Thinkphp5.0 框架实现控制器向视图view赋值及视图view取值操作。分享给大家供大家参考,具体如下:Thinkph
- 本文实例为大家分享了vue实现下拉菜单树的具体代码,供大家参考,具体内容如下效果:使用 Vue-Treeselect 实现建议通过npm安装
- mysql版本:8.0.28xtrabackup版本:8.0.281、安装xtrabackup下载地址:Download Percona X
- fmtfmt是go语言中的格式化输入输出库,其中主要分为两个部分,分别是输出部分和输入部分。输出PrintPrint函数的主要功能是输出,和