Python K-means实现简单图像聚类的示例代码
作者:xiongxyowo 发布时间:2023-06-30 10:40:58
标签:Python,K-means,图像聚类
这里直接给出第一个版本的直接实现:
import os
import numpy as np
from sklearn.cluster import KMeans
import cv2
from imutils import build_montages
import matplotlib.image as imgplt
image_path = []
all_images = []
images = os.listdir('./images')
for image_name in images:
image_path.append('./images/' + image_name)
for path in image_path:
image = imgplt.imread(path)
image = image.reshape(-1, )
all_images.append(image)
clt = KMeans(n_clusters=2)
clt.fit(all_images)
labelIDs = np.unique(clt.labels_)
for labelID in labelIDs:
idxs = np.where(clt.labels_ == labelID)[0]
idxs = np.random.choice(idxs, size=min(25, len(idxs)),
replace=False)
show_box = []
for i in idxs:
image = cv2.imread(image_path[i])
image = cv2.resize(image, (96, 96))
show_box.append(image)
montage = build_montages(show_box, (96, 96), (5, 5))[0]
title = "Type {}".format(labelID)
cv2.imshow(title, montage)
cv2.waitKey(0)
主要需要注意的问题是对K-Means原理的理解。K-means做的是对向量的聚类,也就是说,假设要处理的是224×224×3的RGB图像,那么就得先将其转为1维的向量。在上面的做法里,我们是直接对其展平:
image = image.reshape(-1, )
那么这么做的缺陷也是十分明显的。例如,对于两张一模一样的图像,我们将前者向左平移一个像素。这么做下来后两张图像在感官上几乎没有任何区别,但由于整体平移会导致两者的图像矩阵逐像素比较的结果差异巨大。以橘子汽车聚类为例,实验结果如下:
可以看到结果是比较差的。因此,我们进行改进,利用ResNet-50进行图像特征的提取(embedding),在特征的基础上聚类而非直接在像素上聚类,代码如下:
import os
import numpy as np
from sklearn.cluster import KMeans
import cv2
from imutils import build_montages
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
resnet50 = models.resnet50(pretrained=True)
self.resnet = nn.Sequential(resnet50.conv1,
resnet50.bn1,
resnet50.relu,
resnet50.maxpool,
resnet50.layer1,
resnet50.layer2,
resnet50.layer3,
resnet50.layer4)
def forward(self, x):
x = self.resnet(x)
return x
net = Net().eval()
image_path = []
all_images = []
images = os.listdir('./images')
for image_name in images:
image_path.append('./images/' + image_name)
for path in image_path:
image = Image.open(path).convert('RGB')
image = transforms.Resize([224,244])(image)
image = transforms.ToTensor()(image)
image = image.unsqueeze(0)
image = net(image)
image = image.reshape(-1, )
all_images.append(image.detach().numpy())
clt = KMeans(n_clusters=2)
clt.fit(all_images)
labelIDs = np.unique(clt.labels_)
for labelID in labelIDs:
idxs = np.where(clt.labels_ == labelID)[0]
idxs = np.random.choice(idxs, size=min(25, len(idxs)),
replace=False)
show_box = []
for i in idxs:
image = cv2.imread(image_path[i])
image = cv2.resize(image, (96, 96))
show_box.append(image)
montage = build_montages(show_box, (96, 96), (5, 5))[0]
title = "Type {}".format(labelID)
cv2.imshow(title, montage)
cv2.waitKey(0)
可以发现结果明显改善:
来源:https://blog.csdn.net/qq_40714949/article/details/120854418
0
投稿
猜你喜欢
- string模块可以追溯到早期版本的Python。以前在本模块中实现的许多功能已经转移到str物品。这个string模块保留了几个有用的常量
- 最近对H1的讨论很多(在文章内容页中),大致有以下两种情况:H1应该用于文章的标题上H1应该用于站点的标题上相信大多数人都偏向第一种方式:用
- PHP crc32() 函数实例输出 crc32() 的结果:<?php $str = crc32("Hello World
- Microsoft Access 数据库 (.mdb) 文件大小2 G 字节。不
- 背景:pony是公司的首席体验官、首席产品经理。这次在产品峰会上pony将自己平时经验的积累与大家交流,体验较细。这次分享研发管理部,设计中
- <ScriptRUNAT=SERVERLanguage=VBScript>SubApplication_OnStar
- 前言在进行业务数据分析时,往往需要使用pandas计算环比、同比及增长率等指标,为了能够更加方便的进行的统计数据,整理方法如下。1.数据准备
- 学习目的: 掌握ADO.NET打开SQL SERVER数据库的方法。 今天做个非常普通的例子,做一个用户登录框。主要是通过这个练习认识一下S
- 论坛有人问起如何获取读取CSS属性值,就写了下面这段兼容各浏览器的获取HTML元素的css属性值函数:function getSt
- 本文实例为大家分享了tkinter实现页面跳转的具体代码,供大家参考,具体内容如下主函数main.pyfrom tkinter import
- 1 拷贝下面的代码到一个文件,并命名为forkcore.pyimport osimport threadingimport selectim
- 如图,今天跑代码的事后遇到的问题,pycharm导入我自己写的各种函数.py文件时有红色标注,显示“no moudle balabala…”
- 一、何谓ASP缓存/为什么要缓存当你的web站点采用asp技术建立的初期,可能感觉到的是asp * 页技术带来的便利性,以及随意修改性、自如
- 在Python我们要判断一个文件对当前用户有没有读、写、执行权限,我们通常可以使用os.access函数来实现,比如:# 判断读权限os.a
- 通常我们会在一些javascript的书籍上看到,使用Javascript保留字作为标识符(变量名、函数名、循环标记等)时,会引起程序报错!
- tornado 默认有一个模板引擎但是功能简单(其实我能用到的都差不多)使用起来颇为麻烦, 而jinja2语法与django模板相似所以决定
- 网页设计中的脏、乱、差,是我们在设计过程中常会遇到的问题。通常"脏"是由对色彩使用不当所产生的,而色彩使用不当产生的不好
- 一、概念1、模块化代码可以使代码易于维护和调试,并且提高代码的重用性;2、函数可以用来减少冗余的代码并提高代码的可重用性。函数也可以用来模块
- (1) 我们先用arange函数创建一个数组并改变其维度,使之变成一个三维数组:>>> a = np.arange(24)
- Python 对代码的缩进要求非常严格,同一个级别代码块的缩进量必须一样,否则解释器会报 SyntaxError 异常错误。在 Python