python深度学习之多标签分类器及pytorch实现源码
作者:鬼道2022 发布时间:2022-09-26 01:09:12
标签:多标签,分类器,pytorch,源码
多标签分类器
多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分类任务有有两大特点:
类标数量不确定,有些样本可能只有一个类标,有些样本的类标可能高达几十甚至上百个
类标之间相互依赖,例如包含蓝天类标的样本很大概率上包含白云
如下图所示,即为一个多标签分类学习的一个例子,一张图片里有多个类别,房子,树,云等,深度学习模型需要将其一一分类识别出来。
多标签分类器损失函数
代码实现
针对图像的多标签分类器pytorch的简化代码实现如下所示。因为图像的多标签分类器的数据集比较难获取,所以可以通过对mnist数据集中的每个图片打上特定的多标签,例如类别1的多标签可以为[1,1,0,1,0,1,0,0,1],然后再利用重新打标后的数据集训练出一个mnist的多标签分类器。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.Sq1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2), # (16, 28, 28) # output: (16, 28, 28)
nn.ReLU(),
nn.MaxPool2d(kernel_size=2), # (16, 14, 14)
)
self.Sq2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2), # (32, 14, 14)
nn.ReLU(),
nn.MaxPool2d(2), # (32, 7, 7)
)
self.out = nn.Linear(32 * 7 * 7, 100)
def forward(self, x):
x = self.Sq1(x)
x = self.Sq2(x)
x = x.view(x.size(0), -1)
x = self.out(x)
## Sigmoid activation
output = F.sigmoid(x) # 1/(1+e**(-x))
return output
def loss_fn(pred, target):
return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).sum()
def multilabel_generate(label):
Y1 = F.one_hot(label, num_classes = 100)
Y2 = F.one_hot(label+10, num_classes = 100)
Y3 = F.one_hot(label+50, num_classes = 100)
multilabel = Y1+Y2+Y3
return multilabel
# def multilabel_generate(label):
# multilabel_dict = {}
# multi_list = []
# for i in range(label.shape[0]):
# multi_list.append(multilabel_dict[label[i].item()])
# multilabel_tensor = torch.tensor(multi_list)
# return multilabel
def train():
epoches = 10
mnist_net = CNN()
mnist_net.train()
opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002)
mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 128, shuffle=True)
for epoch in range(epoches):
loss = 0
for batch_X, batch_Y in train_loader:
opitimizer.zero_grad()
outputs = mnist_net(batch_X)
loss = loss_fn(outputs, multilabel_generate(batch_Y)) / batch_X.shape[0]
loss.backward()
opitimizer.step()
print(loss)
if __name__ == '__main__':
train()
来源:https://guidao.blog.csdn.net/article/details/122085474


猜你喜欢
- 最近两周由于忙于个人项目,一直未发言了,实在是太荒凉了。。。。,上周由于项目,见到Python的应用极为广泛,用起来也特别顺手,于是小编也开
- 本文记录了python安装及环境配置方法,具体内容如下Python安装 Windowns操作系统中安装Python步骤一 下载安装包从Pyt
- 大名鼎鼎的FCKeditor终于在最近发布新版本了,与增加版本号不同,这次完全把它改名了,更名为CKeditor。这应该是和它的开发公司CK
- 先上一张效果图:以前用angularjs操作基本上都是要取到每个列表的id再循环判断是不是当前点击的列表来显示折叠。今天在这个项目 http
- 配置Laravel 的邮件服务可以通过 config/mail.php 配置文件进行配置。邮件中的每一项都在配置文件中有单独的配置项,甚至是
- 最近在学习正则,一些比较有用的东西怕忘记,记下来,比较乱,想一条记录一条:正则表达式在线测试//匹配文本,这个偶尔比较好用,但是要小心字符中
- 需要将字符串中的空格去掉的情况,可以使用下面几种解决方法:1、strip()方法:该方法只能把字符串头和尾的空格去掉,但是不能将字符串中间的
- 统计每天的数据量变化,数据量变动超过一定范围时,进行告警。告警通过把对应的参数传递至相应接口。python程序如下#!/usr/bin/py
- 查看两个数据库的同名表的字段名差异问题描述开发过程中有多个测试环境,测试环境 A 加了字段,测试环境 B 忘了加,字段名对不上,同一项目就报
- 本文实例讲述了Python实现字典去除重复的方法。分享给大家供大家参考,具体如下:#!/usr/bin/env python# encodi
- 如下所示:INPUT = c_int * 4# 实例化一个长度为2的整型数组input = INPUT()# 为数组赋值(input这个数组
- 1、查看当前所有连接的详细资料:./mysqladmin -uadmin -p -h10.140.1.1 processlist2、只查看当
- 1. 引言本文是数独游戏问题求解的第二篇,在前文中我们使用回溯算法实现了最简单版本的数独游戏求解方案。本文主要在前文解决方案的基础上,来思考
- 安装好jupyter notebook后,在pycharm中无论运行什么样的python脚本,都会默认使用python的console运行,
- Timer继承子Thread类,是Thread的子类,也是线程类,具有线程的能力和特征。这个类用来定义多久执行一个函数。它的实例是能够延迟执
- 1. 手动操作1.1. 显示模块pip list1.2. 显示过期模块pip list --outdated1.3. 安装模块pip ins
- 这篇文章主要介绍了Python文本处理简单易懂方法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋
- 文字的多行处理在dom元素中很好办。但是canvas中没有提供方法,只有通过截取指定字符串来达到目的。那么下面就介绍我自己处理的办法:wxm
- 1.概述Kivy是一套Python下的跨平台开源应用开发框架,官网,我们可以用它来将Python程序打包为安卓的apk安装文件。以下是在wi
- 如何使用ADO 2x Command 对象读取数据?具体的读数据代码如下:Cmd = CType(EC.Example1