小言_互联网的博客

求助:帮忙解读一段人工智能代码

468人阅读  评论(0)

帮忙看下这段人工智能代码,最好每一句都解释一下。

# Standard library
import os  # path handling (os.path.join) for the dataset folders
import time  # wall-clock timing of the training run

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision  # PyTorch's companion package for vision data and models
from torch.autograd import Variable
from torchvision import datasets, models, transforms  # three torchvision subpackages

# NOTE: "%matplotlib inline" is an IPython/Jupyter magic, not Python syntax, so it
# cannot appear in a .py file. If you run this in a notebook and want figures
# rendered inline, put that magic in its own notebook cell.

# Root folder of the dataset. ImageFolder expects one sub-folder per class inside
# each split, i.e. path/train/<class>/*.jpg and path/val/<class>/*.jpg.
path = "data/catsvsdog"

# Preprocessing applied to every image:
#   Resize(256)     - scale the shorter side to 256 so the crop below sees most of
#                     the photo instead of a small central patch of a large image
#   CenterCrop(224) - 224x224 input, the size VGG16's classifier was built for
#   ToTensor        - PIL image -> float tensor in [0, 1], channels first
#   Normalize       - (x - 0.5) / 0.5 maps each channel to [-1, 1]
# NOTE(review): ImageNet-pretrained VGG16 is normally fed ImageNet statistics
# (mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]). The 0.5 values are kept
# because the display code in this script undoes exactly this 0.5/0.5 normalization.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# One ImageFolder dataset per split; labels come from the sub-folder names.
data_image = {
    x: datasets.ImageFolder(root=os.path.join(path, x), transform=transform)
    for x in ["train", "val"]
}

# Mini-batches of 4 images. Only the training split is shuffled; shuffling the
# validation split has no effect on its metrics.
data_loader_image = {
    x: torch.utils.data.DataLoader(dataset=data_image[x],
                                   batch_size=4,
                                   shuffle=(x == "train"))
    for x in ["train", "val"]
}

# Class names (the train/ sub-folder names) and their name -> label-index mapping.
classes = data_image["train"].classes
classes_index = data_image["train"].class_to_idx
print(classes)
print(classes_index)

# Number of images in each split.
print(len(data_image["train"]))
print(len(data_image["val"]))

# `model` is only created further down, after the data has been inspected, and
# it is printed there. Printing it here raised NameError, so that call is gone.

# Whether a CUDA GPU is available; later steps move tensors and the model to the
# GPU only when this is True.
use_gpu = torch.cuda.is_available()
print(use_gpu)

# Dataset sizes with labels. `classes` / `classes_index` were already computed and
# printed from data_image["train"] above, so they are not recomputed here.
print("训练集个数:", len(data_image["train"]))  # number of training images
print("验证集个数:", len(data_image["val"]))  # number of validation images

# Visual sanity check: take one training batch and show it as an image grid
# together with its class names.
X_train, y_train = next(iter(data_loader_image["train"]))
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
img = torchvision.utils.make_grid(X_train)
# make_grid returns a channels-first (C, H, W) tensor; matplotlib wants (H, W, C).
img = img.numpy().transpose((1, 2, 0))
# Undo Normalize(0.5, 0.5) to get back to [0, 1]. Clip the small floating-point
# overshoot, which imshow would otherwise warn about.
img = np.clip(img * std + mean, 0, 1)

print([classes[i] for i in y_train.tolist()])
plt.imshow(img)
plt.show()  # needed to open the figure when not running inside Jupyter

# VGG16 with ImageNet-pretrained weights; print the original architecture.
model = models.vgg16(pretrained=True)
print(model)

# Freeze every pretrained parameter so only the new head below gets trained.
for parameter in model.parameters():
    parameter.requires_grad = False

# Replace the classifier head. Its layers are created after the freeze loop, so
# their parameters are trainable. 25088 (= 512 * 7 * 7) matches the input size of
# torchvision's original VGG16 classifier. The output size follows the number of
# class folders instead of a hard-coded 2 (it is 2 for cats vs. dogs).
model.classifier = torch.nn.Sequential(
    torch.nn.Linear(25088, 4096),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(4096, 4096),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(4096, len(classes)),
)

if use_gpu:
    model = model.cuda()

# Cross-entropy on raw logits. Adam only optimizes the new classifier, since the
# frozen feature layers receive no gradients.
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters())
print(model)

# Fine-tune the classifier head: each epoch runs one training pass and one
# validation pass, then the final weights are saved.
n_epochs = 1
# Time the whole run. Previously the timer was reset every epoch but reported
# only once, after the loop.
since = time.time()
for epoch in range(n_epochs):
    print("Epoch{}/{}".format(epoch + 1, n_epochs))  # 1-based: "Epoch1/1"
    print("-" * 10)
    for param in ["train", "val"]:
        is_train = param == "train"
        # Switch Dropout between training and inference behaviour. Assigning
        # `model.train = True/False` (the previous code) does not do this: it
        # replaces the nn.Module.train method with a plain bool attribute.
        model.train(is_train)

        running_loss = 0.0  # sum of per-sample losses
        running_correct = 0  # number of correct predictions
        seen = 0  # number of samples processed (the last batch may be < 4)
        for batch, (X, y) in enumerate(data_loader_image[param], start=1):
            if use_gpu:
                X, y = X.cuda(), y.cuda()

            # Build the autograd graph only when we will backpropagate.
            with torch.set_grad_enabled(is_train):
                y_pred = model(X)
                loss = cost(y_pred, y)
            _, pred = torch.max(y_pred, 1)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # CrossEntropyLoss() with its default reduction returns the batch
            # mean, so weight it by the actual batch size to get correct averages.
            n = y.size(0)
            running_loss += loss.item() * n
            running_correct += (pred == y).sum().item()
            seen += n
            if batch % 10 == 0 and is_train:
                print("Batch {}, Train Loss:{:.4f}, Train ACC:{:.4f}".format(
                    batch, running_loss / seen, 100 * running_correct / seen))

        epoch_loss = running_loss / seen
        epoch_correct = 100 * running_correct / seen
        print("{}  Loss:{:.4f},  Correct{:.4f}".format(param, epoch_loss, epoch_correct))

now_time = time.time() - since
print("Training time is:{:.0f}m {:.0f}s".format(now_time // 60, now_time % 60))

# Persist only the learned weights (state_dict), not the whole module object.
torch.save(model.state_dict(), "model_vgg16_finetune.pkl")

# Predict one batch of validation images with the fine-tuned model and show the
# images alongside the predicted class names.
data_test_img = datasets.ImageFolder(root=os.path.join(path, "val"),
                                     transform=transform)
data_loader_test_img = torch.utils.data.DataLoader(dataset=data_test_img,
                                                   batch_size=4)

# Inference mode (turns Dropout off). This is what model.eval() does. It is called
# through the class so it still works if `model.train` has been overwritten by an
# attribute assignment such as `model.train = False`.
torch.nn.Module.train(model, False)

image, label = next(iter(data_loader_test_img))
# Use the GPU only when one is available, so the script also runs on CPU.
images = image.cuda() if use_gpu else image
with torch.no_grad():  # prediction needs no gradients
    y_pred = model(images)
_, pred = torch.max(y_pred, 1)
print(pred)

img = torchvision.utils.make_grid(image)
# (C, H, W) -> (H, W, C) for matplotlib.
img = img.numpy().transpose(1, 2, 0)
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
# Undo Normalize(0.5, 0.5) and clip rounding overshoot into [0, 1] for imshow.
img = np.clip(img * std + mean, 0, 1)
print("Pred Label:", [classes[i] for i in pred.tolist()])
plt.imshow(img)
plt.show()  # needed to open the figure when not running inside Jupyter

如果你想加载一篇你写过的 .md 或 .html 文件,可以在上方工具栏中选择“导入”功能,导入对应扩展名的文件,继续你的创作。


转载:https://blog.csdn.net/YangZhiYuanroman/article/details/89387699
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场