PyTorch零基础入门之逻辑斯蒂回归
作者:山顶夕景 发布时间:2022-09-09 03:44:25
标签:PyTorch,逻辑斯蒂回归,PyTorch
学习总结
(1)和上一讲的模型训练是类似的,只是在线性模型的基础上加个sigmoid,然后loss函数改为交叉熵BCE函数(当然也可以用其他函数),另外一开始的数据y_data也从数值改为类别0和1(本例为二分类,注意x_data
和y_data
这里也是矩阵的形式)。
一、sigmoid函数
logistic function是一种sigmoid函数(还有其他sigmoid函数),但由于使用过于广泛,pytorch默认logistic function叫为sigmoid函数。还有如下的各种sigmoid函数:
二、和Linear的区别
逻辑斯蒂和线性模型的unit区别如下图:
sigmoid
函数是不需要参数的,所以不用对其初始化(直接调用nn.functional.sigmoid
即可)。
另外loss函数从MSE改用交叉熵BCE:尽可能和真实分类贴近。
如下图右方表格所示,当 y ^ \hat{y} y^越接近y时则BCE Loss值越小。
三、逻辑斯蒂回归(分类)PyTorch实现
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 18 08:35:00 2021
@author: 86493
"""
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
# 准备数据
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
losslst = []
class LogisticRegressionModel(nn.Module):
def __init__(self):
super(LogisticRegressionModel, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
# 和线性模型的网络的唯一区别在这句,多了F.sigmoid
y_predict = F.sigmoid(self.linear(x))
return y_predict
model = LogisticRegressionModel()
# 使用交叉熵作损失函数
criterion = torch.nn.BCELoss(size_average = False)
optimizer = torch.optim.SGD(model.parameters(),
lr = 0.01)
# 训练
for epoch in range(1000):
y_predict = model(x_data)
loss = criterion(y_predict, y_data)
# 打印loss对象会自动调用__str__
print(epoch, loss.item())
losslst.append(loss.item())
# 梯度清零后反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 画图
plt.plot(range(1000), losslst)
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.show()
# test
# 每周学习的时间,200个点
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
# 画 probability of pass = 0.5的红色横线
plt.plot([0, 10], [0.5, 0.5], c = 'r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
可以看出处于通过和不通过的分界线是Hours=2.5。
Reference
pytorch官方文档
来源:https://blog.csdn.net/qq_35812205/article/details/120820397
0
投稿
猜你喜欢
- 目标网址:https://www.baidu.com/要获取的内容:链接分析:从下图可以看出只需要获取关键字,再构建就可以了。完整代码:im
- 本文实例讲述了python类和对象用法。分享给大家供大家参考,具体如下:前面我们都是用python面向过程编程,现在来用python创建类和
- 本文实例为大家分享了php微信公众号开发之快递查询的具体代码,供大家参考,具体内容如下快递查询数组用法foreach查询接口是:爱快递:ht
- XHTML规范中有一条标准就是“每个XHTML标签都有一个结束标记”。那么对于HTML中原来不带结束标记的元素,则在该结束前加上“/”来关闭
- 本文实例讲述了PHP实现登录,注册及密码修改功能的方法。分享给大家供大家参考,具体如下:这里介绍注册,登录,修改密码的界面布局与功能实现:1
- 前言如题目所述,又是花费了两天的时间实现了该功能,本来今天下午有些心灰意冷,打算放弃嵌入到Scoll Area中的想法,但最后还是心里一紧,
- 1 将文件保存到服务器本地upload.html<!DOCTYPE html><html lang="en&qu
- 前不久微信上线了拍一拍功能,刚推出就被有才的网友玩坏了。还有更多没有节操的拍法这里就不展示了。但拍一拍属于弱提示,只有在聊天界面才能感受到。
- python eval函数功能:将字符串str当成有效的表达式来求值并返回计算结果。函数定义:eval(expression, global
- ASP长文章分页代码实例,也许你会问一篇文章为什么还要进行分页呢?因为文章有短有长,当你的文章很长的时候,如果就一个页面都显示出来的话,读者
- 前言今天,在网上发现一款很棒的python画图工具库。很简单的api调用就能生成漂亮的图表。并且可以进行一些互动。pyecharts 是一个
- 大多的MySQL都是装在Linux上的,而我们的本机上一般都会装MySQL-Front.那如何用MySQL-Front连接远端Linux系统
- pickle的作用:1:pickle.dump(dict,file)把字典转为二进制存入文件.2:pickle.load(file)把文件二
- 使用python写爬虫时,优选selenium,由于PhantomJS因内部原因已经停止更新,最新版的selenium已经使用headles
- ndarray的转置(transpose)对于A是由np.ndarray表示的情况:可以直接使用命令A.T。也可以使用命令A.transpo
- 如下所示:# coding:utf-8import osfrom PIL import Image# bmp 转换为jpgdef bmpTo
- 使用“发送测试电子邮件”对话框来测试使用特定配置文件发送邮件的能力。过程发送测试电子邮件1.使用对象
- 直方图处理直方图从图像内部灰度级的角度对图像进行表述从直方图的角度对图像进行处理,可以达到增强图像显示效果的目的。直方图的含义直方图是图像内
- 虽然Golang的GC自打一开始,就被人所诟病,但是经过这么多年的发展,Golang的GC已经改善了非常多,变得非常优秀了。以下是Golan
- 本文实例讲述了Python实现读取txt文件并转换为excel的方法。分享给大家供大家参考,具体如下:这里的txt文件内容格式为:892天平