python的numpy模块实现逻辑回归模型
作者:上进的小菜鸟 发布时间:2022-10-01 07:05:59
标签:python,numpy,逻辑回归
使用python的numpy模块实现逻辑回归模型的代码,供大家参考,具体内容如下
使用了numpy模块,pandas模块,matplotlib模块
1.初始化参数
def initial_para(nums_feature):
"""initial the weights and bias which is zero"""
#nums_feature是输入数据的属性数目,因此权重w是[1, nums_feature]维
#且w和b均初始化为0
w = np.zeros((1, nums_feature))
b = 0
return w, b
2.逻辑回归方程
def activation(x, w , b):
"""a linear function and then sigmoid activation function:
x_ = w*x +b,y = 1/(1+exp(-x_))"""
#线性方程,输入的x是[batch, 2]维,输出是[1, batch]维,batch是模型优化迭代一次输入数据的数目
#[1, 2] * [2, batch] = [1, batch], 所以是w * x.T(x的转置)
#np.dot是矩阵乘法
x_ = np.dot(w, x.T) + b
#np.exp是实现e的x次幂
sigmoid = 1 / (1 + np.exp(-x_))
return sigmoid
3.梯度下降
def gradient_descent_batch(x, w, b, label, learning_rate):
#获取输入数据的数目,即batch大小
n = len(label)
#进行逻辑回归预测
sigmoid = activation(x, w, b)
#损失函数,np.sum是将矩阵求和
cost = -np.sum(label.T * np.log(sigmoid) + (1-label).T * np.log(1-sigmoid)) / n
#求对w和b的偏导(即梯度值)
g_w = np.dot(x.T, (sigmoid - label.T).T) / n
g_b = np.sum((sigmoid - label.T)) / n
#根据梯度更新参数
w = w - learning_rate * g_w.T
b = b - learning_rate * g_b
return w, b, cost
4.模型优化
def optimal_model_batch(x, label, nums_feature, step=10000, batch_size=1):
"""train the model with batch"""
length = len(x)
w, b = initial_para(nums_feature)
for i in range(step):
#随机获取一个batch数目的数据
num = randint(0, length - 1 - batch_size)
x_batch = x[num:(num+batch_size), :]
label_batch = label[num:num+batch_size]
#进行一次梯度更新(优化)
w, b, cost = gradient_descent_batch(x_batch, w, b, label_batch, 0.0001)
#每1000次打印一下损失值
if i%1000 == 0:
print('step is : ', i, ', cost is: ', cost)
return w, b
5.读取数据,数据预处理,训练模型,评估精度
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from random import randint
from sklearn.preprocessing import StandardScaler
def _main():
#读取csv格式的数据data_path是数据的路径
data = pd.read_csv('data_path')
#获取样本属性和标签
x = data.iloc[:, 2:4].values
y = data.iloc[:, 4].values
#将数据集分为测试集和训练集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state=0)
#数据预处理,去均值化
standardscaler = StandardScaler()
x_train = standardscaler.fit_transform(x_train)
x_test = standardscaler.transform(x_test)
#w, b = optimal_model(x_train, y_train, 2, 50000)
#训练模型
w, b = optimal_model_batch(x_train, y_train, 2, 50000, 64)
print('trian is over')
#对测试集进行预测,并计算精度
predict = activation(x_test, w, b).T
n = 0
for i, p in enumerate(predict):
if p >=0.5:
if y_test[i] == 1:
n += 1
else:
if y_test[i] == 0:
n += 1
print('accuracy is : ', n / len(y_test))
6.结果可视化
predict = np.reshape(np.int32(predict), [len(predict)])
#将预测结果以散点图的形式可视化
for i, j in enumerate(np.unique(predict)):
plt.scatter(x_test[predict == j, 0], x_test[predict == j, 1],
c = ListedColormap(('red', 'blue'))(i), label=j)
plt.show()
来源:https://blog.csdn.net/qq_35153620/article/details/95763896


猜你喜欢
- next()方法当一个文件被用作迭代器,典型例子是在一个循环中被使用,next()方法被反复调用。此方法返回下一个输入行,或引发
- 一、安装redis因为是在CentOS系统下安装的,并且是服务器。遇到的困难有点多不过。1.首先要下载相关依赖首先先检查是否有c语言的编译环
- 设计与开发之间本有一线界限,但当时代步入又一个十年,这个线变得更加模糊甚至感觉不到它的存在。使用PS设计网页版面,足矣?或许五年前是吧!现在
- 曾经有许多创造性的logo设计案例,logo设计资源和logo设计指导张贴在互联网的各个角落。这些帮助会为你的logo设计创造一个功能强大的
- 1.常用数据结构之列表我们先给大家一个编程任务,将一颗色子掷6000次,统计每个点数出现的次数。这个任务对大家来说应该是非常简单的,我们可以
- 在学习python代码时,看到有的类的方法中第一参数是cls,有的是self,经过了解得知,python并没有对类中方法的第一个参数名字做限
- 前言大部分人在日常的业务开发中,其实很少去关注数据库的事务相关问题,基本上都是 CURD 一把梭。正好最近在看 MySQL 的相关基础知识,
- Kubernetes的控制器模式是其非常重要的一个设计模式,整个Kubernetes定义的资源对象以及其状态都保存在etcd数据库中,通过a
- 在使用matplotlib画图时,少不了对性能图形做出一些说明和补充。一般情况下,loc属性设置为'best'就足够应付了p
- 一、Jenkins 是什么?Jenkins是一款开源 CI&CD 软件,用于自动化各种任务,包括构建、测试和部署软件。二、准备工作安
- DW2004的中文乱码情况你遇到过么?乱码一般是怎么出现的呢?也许很多时候用其他软件(比如Editplus)写程序的时候,忘了meta标签里
- Log包Go语言提供的默认日志包:https://golang.org/pkg/log/基本用法log包定义了Logger类型,该类型提供了
- 方式1:引入普通的js文件,如user.js1.1、属性和方法都写在一个变量内部const user={ logi
- 近期线上出现一个bug,研发的小伙伴把测试环境的地址写死到代码中,在上线前忘记修改,导致线上发布的代码中使用了测试环境地址。开发过程中虽然有
- 使用PHP GD,使用良好,一键剪裁各种尺寸,打包下载。经常换icon的懂的,美工给你一个1024的logo,你得ps出各种尺寸,于是有了这
- 看到很多站长工具网,都提供了通过域名获取网站IP的方法。自己也想做一个,网上查了不少代码。有说用WSHSHELL,也有说用ASPPING组件
- 需要处理原始的音频,所以给服务器的环境安装librosa的包pip install librosa直接pip install librosa
- 这个操作现在看来真没啥难的,但是我找相关的资料真的找了好久。多数大佬都是直接pandas官网甩我脸上,然后举一个入门级的例子。https:/
- python2:print语句,语句就意味着可以直接跟要打印的东西,如果后面接的是一个元组对象,直接打印python3:print函数,函数
- 前言项目开发中,产品经理提了这样一个需求:将系统中的附件实现批量打包下载功能。本来系统中是有单个下载及批量下载功能,现在应业务方的需求,需要