JuneoXIE
5/19/2019 - 12:41 PM

Pytorch常用工具

torch 模块

torch基本函数汇总

数据(图像)操作 torchvision.transforms 模块

1. PIL.Image/numpy.ndarray与Tensor的相互转化:transforms.ToTensor()

from torchvision import transforms
transforms.ToTensor()

2. 对torch.*Tensor进行归一化

transforms.Normalize(mean, std)

此转换类作用于torch.*Tensor。给定均值mean(R, G, B)和标准差str(R, G, B),用公式channel = (channel - mean) / std进行规范化。 是对tensor进行归一化,所以需要放在transforms.ToTensor()之后

3. 对PIL.Image进行裁剪、缩放等操作

transforms.Scale(256),
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip()

4. 进行多种操作的组合

对图片进行多种操作。这个时候,需要用transforms.Compose(transforms)将多个transform组合起来使用

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
all_transforms = transforms.Compose([
                    transforms.Scale(256),
                    transforms.RandomSizedCrop(224),
                    transforms.RandomHorizontalFlip(), # 对PIL.Image图片进行操作
                    transforms.ToTensor(),
                    normalize])

5. PIL.Image/numpy.ndarray与Tensor的相互转换

#-*-coding:utf-8-*-
import torch
from torchvision import transforms
from PIL import Image
import cv2

img_path = "./cat.59.jpg"  

transform1 = transforms.Compose([
    transforms.CenterCrop((224,224)), # 对PIL图片进行裁剪
    transforms.ToTensor(), # 转化为Tensor
    ])

## PIL图片转化为Tensor
img_PIL = Image.open(img_path).convert('RGB')
img_PIL.show() # 原始图片
img_PIL_Tensor = transform1(img_PIL)
print(type(img_PIL))
print(type(img_PIL_Tensor))

#Tensor转成PIL.Image重新显示
new_img_PIL = transforms.ToPILImage()(img_PIL_Tensor).convert('RGB')
new_img_PIL.show() # 处理后的PIL图片

## opencv读取的图片与Tensor互转
# transforms中,没有对np.ndarray格式图像的操作
img_cv = cv2.imread(img_path)
transform2 = transforms.Compose([
    transforms.ToTensor(), 
    ])

img_cv_Tensor = transform2(img_cv)
print(type(img_cv))
print(type(img_cv_Tensor))