def forward(self, x):
    """Run the forward pass of the network on the batch ``x``.

    Layers are applied strictly in this order:
    conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> flatten
    -> fc1 -> relu3 -> fc2 -> relu4 -> fc3 -> relu5.

    Args:
        x: input batch. Presumably a (batch, channels, height, width) tensor,
            as expected by the conv layers -- TODO confirm against the layer
            definitions in ``__init__``.

    Returns:
        The output of ``relu5`` applied to ``fc3``'s output; the first
        dimension is the batch dimension of ``x``.
    """
    # Feature extraction: conv1/relu1/pool1, then conv2/relu2/pool2.
    y = self.conv1(x)
    y = self.relu1(y)
    y = self.pool1(y)
    y = self.conv2(y)
    y = self.relu2(y)
    y = self.pool2(y)
    # Flatten every dimension except the first (batch) one for the fc layers.
    y = y.view(y.shape[0], -1)
    # Head: fc1..fc3 (presumably fully connected), with relu3/relu4 between.
    y = self.fc1(y)
    y = self.relu3(y)
    y = self.fc2(y)
    y = self.relu4(y)
    y = self.fc3(y)
    # NOTE(review): assuming relu5 is a ReLU, this clamps the final outputs
    # to be non-negative. If they are fed to CrossEntropyLoss as logits, the
    # clamp is likely unintended. It is kept to preserve the behavior of
    # existing trained models -- confirm before removing.
    y = self.relu5(y)
    return y
Dataset loading code:
import os

import torch
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Base directory for the dataset files. It defaults to the original author's
# Windows drive "E:\" and can be overridden with the DATASET_ROOT environment
# variable so the script also runs on other machines.
# NOTE(review): how `root` is combined with the index/image paths is not
# shown here -- keep a trailing path separator if callers concatenate it.
root = os.environ.get("DATASET_ROOT", "E:\\")
# -----------------ready the dataset-------------------------- defdefault_loader(path): return Image.open(path).convert('L')
class MyDataset(Dataset):
    """Dataset built from a text index file of ``<image path> <label>`` lines.

    Each non-blank line of the index file is split on whitespace. The first
    field is kept as the image path and the second is parsed with ``int`` as
    the label; any further fields are ignored.

    NOTE(review): ``__getitem__`` / ``__len__`` are not part of this snippet.
    ``torch.utils.data.Dataset`` subclasses need them to be usable with a
    ``DataLoader`` -- confirm they are defined.
    """

    # Constructor with default arguments.
    def __init__(self, txt, transform=None, target_transform=None,
                 loader=default_loader):
        """Read the index file and store the per-sample settings.

        Args:
            txt: path of the index text file.
            transform: stored as ``self.transform``; not applied here.
            target_transform: stored as ``self.target_transform``; not
                applied here.
            loader: callable stored as ``self.loader``; defaults to
                ``default_loader``.

        Raises:
            OSError: if ``txt`` cannot be opened.
            IndexError: if a non-blank line has fewer than two fields.
            ValueError: if a line's second field is not an integer.
        """
        imgs = []
        # Read inside `with` so the index file is closed deterministically
        # (the previous version never closed it).
        with open(txt, 'r') as fh:
            for line in fh:
                # split() with no argument splits on runs of whitespace and
                # drops the trailing newline and surrounding spaces.
                words = line.split()
                if not words:
                    # Skip blank / whitespace-only lines (e.g. a trailing
                    # empty line) instead of failing on words[0].
                    continue
                # imgs holds (image path, label) pairs.
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
from model import Model import numpy as np import torch from torchvision.datasets import mnist from torch.nn import CrossEntropyLoss from torch.optim import SGD from torch.utils.data import DataLoader from torchvision.transforms import ToTensor from my_dataset import train_dataset, test_dataset