PyTorch零基础入门之逻辑斯蒂回归
作者:山顶夕景 发布时间:2022-09-09 03:44:25
标签:PyTorch,逻辑斯蒂回归,PyTorch
学习总结
(1)和上一讲的模型训练是类似的,只是在线性模型的基础上加个sigmoid,然后loss函数改为交叉熵BCE函数(当然也可以用其他函数),另外一开始的数据y_data也从数值改为类别0和1(本例为二分类,注意x_data
和y_data
这里也是矩阵的形式)。
一、sigmoid函数
logistic function是一种sigmoid函数(还有其他sigmoid函数),但由于使用过于广泛,pytorch默认logistic function叫为sigmoid函数。还有如下的各种sigmoid函数:
二、和Linear的区别
逻辑斯蒂和线性模型的unit区别如下图:
sigmoid
函数是不需要参数的,所以不用对其初始化(直接调用nn.functional.sigmoid
即可)。
另外loss函数从MSE改用交叉熵BCE:尽可能和真实分类贴近。
如下图右方表格所示,当 y ^ \hat{y} y^越接近y时则BCE Loss值越小。
三、逻辑斯蒂回归(分类)PyTorch实现
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 18 08:35:00 2021
@author: 86493
"""
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
# 准备数据
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
losslst = []
class LogisticRegressionModel(nn.Module):
def __init__(self):
super(LogisticRegressionModel, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
# 和线性模型的网络的唯一区别在这句,多了F.sigmoid
y_predict = F.sigmoid(self.linear(x))
return y_predict
model = LogisticRegressionModel()
# 使用交叉熵作损失函数
criterion = torch.nn.BCELoss(size_average = False)
optimizer = torch.optim.SGD(model.parameters(),
lr = 0.01)
# 训练
for epoch in range(1000):
y_predict = model(x_data)
loss = criterion(y_predict, y_data)
# 打印loss对象会自动调用__str__
print(epoch, loss.item())
losslst.append(loss.item())
# 梯度清零后反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 画图
plt.plot(range(1000), losslst)
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.show()
# test
# 每周学习的时间,200个点
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
# 画 probability of pass = 0.5的红色横线
plt.plot([0, 10], [0.5, 0.5], c = 'r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
可以看出处于通过和不通过的分界线是Hours=2.5。
Reference
pytorch官方文档
来源:https://blog.csdn.net/qq_35812205/article/details/120820397


猜你喜欢
- ElasticSearch是一个基于Lucene的搜索服务器。它提供了一个分布式多用户能力的全文搜索引擎,基于RESTful web接口。E
- 像素误差看自己设计好上线的网站,偶尔会发觉像素行间出现了弹性空间,总在不经意间蹦出一定的差距。有些页面很难发现,比如活动类页面,这类页面多呈
- Python中强大的选项处理模块。关于Python之OptionParser模块使用详解可以参考这篇。示例#!/usr/bin/python
- 做教育业的网站,会将此遇到这个问题:如何在网页上显示音标?音标为什么显示为乱字符?等等类似的问题。前两天做沪江网某英语页面的时候也碰到了这个
- python提供了json包来进行json处理,json与python中数据类型对应关系如下:一个python object无法直接与jso
- 在VBScript中有Filter这个函数可以用来对数组进行过滤,并返回原数组的一个子集数组。语法说明: 引用内容Filter 函
- 我们将研究一种判别式分类方法,其中直接学习评估 g(x)所需的 w 参数。我们将使用感知器学习算法。感知器学习算法很容易实现,但为了节省时间
- 自动化整理计算机文件通过Python编程完成文件的自动分类、文件和文件夹的快速查找、重复文件的清理、图片格式的转换等常见工作。1. 文件的自
- 我差不多是与做web design的同时接触的flash design,因为那会普遍认为flash神通广大、无所不能。这些年我看Adobe的
- 本文实例讲述了python使用datetime模块计算各种时间间隔的方法。分享给大家供大家参考。具体分析如下:python中通过dateti
- 在进行python的开发过程中一直倡导使用虚拟环境来进行项目隔离,这样不会因为python的包不同而导致各种问题,但是以往为了图省事简单,安
- 1.match() 从开始位置开始匹配 2.search() 任意位置匹配,如果有多个匹配,只返回第一个 3.finditer() 返回所有
- 最近帮伙计做了一个从网页抓取股票信息并把相应信息存入MySQL中的程序。使用环境:Python 2.5 for WindowsMySQLdb
- 数据库的表Info,表部分结构:Info_Id  
- 一、数据描述数据集中9994条数据,横跨1237天,销售额为2,297,200.8603美元,利润为286,397.0217美元,他们的库存
- 创建项目scrapy startproject zhaoping创建爬虫cd zhaopingscrapy genspider hr zha
- 文件对象提供了 read() 方法来按字节或字符读取文件内容,到底是读取宇节还是字符,则取决于是否使用了 b 模式,如果使用了 b 模式,则
- 在使用json.dumps时要注意一个问题>>> import json>>> print json.d
- 前言关于mockjs,官网描述的是1.前后端分离2.不需要修改既有代码,就可以拦截 Ajax 请求,返回模拟的响应数据。3.数据类型丰富4.
- 本文实例为大家分享了使用XML配置c3p0数据库连接池的具体代码,供大家参考,具体内容如下想通过JDBC来配置c3p0数据库连接池,上网想找