PyTorch - Training a Convnet from Scratch

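  • Overview

    In this chapter, we focus on creating a convnet from scratch. This means using torch to build the respective convnet or a sample neural network. The steps below build a small fully connected network with one hidden layer.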
  • Step 1

    Create the necessary class with its parameters. The parameters include weights initialized with random values.
    
    import torch
    import torch.nn as nn

    class Neural_Network(nn.Module):
        def __init__(self):
            super(Neural_Network, self).__init__()
            self.inputSize = 2
            self.outputSize = 1
            self.hiddenSize = 3
            # weights, initialized with random values
            self.W1 = torch.randn(self.inputSize, self.hiddenSize)  # 2 X 3 tensor
            self.W2 = torch.randn(self.hiddenSize, self.outputSize)  # 3 X 1 tensor
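    A quick way to confirm the shapes noted in the comments is to instantiate the class. The minimal sketch below repeats the Step 1 class and also shows one consequence of storing the weights as plain tensors rather than nn.Parameter: nn.Module does not register them, so parameters() and state_dict() are both empty.

```python
import torch
import torch.nn as nn


class Neural_Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.inputSize = 2
        self.outputSize = 1
        self.hiddenSize = 3
        # Plain tensors, NOT nn.Parameter, so nn.Module does not register them
        self.W1 = torch.randn(self.inputSize, self.hiddenSize)
        self.W2 = torch.randn(self.hiddenSize, self.outputSize)


net = Neural_Network()
print(net.W1.shape)                 # torch.Size([2, 3])
print(net.W2.shape)                 # torch.Size([3, 1])
print(len(list(net.parameters())))  # 0
print(len(net.state_dict()))        # 0
```

    Because the weights are updated by hand in the backward pass, they do not need to be parameters, but a state_dict()-based save would store nothing for them.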
    
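  • Step 2

    Create the feed-forward pass of the network using the sigmoid activation function. This step also defines the sigmoid derivative and a manual backward pass that updates the weights.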
    
        def forward(self, X):
            self.z = torch.matmul(X, self.W1)  # 3 X 3 for a 3 X 2 input; ".dot" does not broadcast in PyTorch
            self.z2 = self.sigmoid(self.z)  # activation function
            self.z3 = torch.matmul(self.z2, self.W2)
            o = self.sigmoid(self.z3)  # final activation function
            return o

        def sigmoid(self, s):
            return 1 / (1 + torch.exp(-s))

        def sigmoidPrime(self, s):
            # derivative of sigmoid, written in terms of its output s = sigmoid(x)
            return s * (1 - s)

        def backward(self, X, y, o):
            self.o_error = y - o  # error in output
            self.o_delta = self.o_error * self.sigmoidPrime(o)  # apply derivative of sigmoid to the error
            self.z2_error = torch.matmul(self.o_delta, torch.t(self.W2))  # error propagated back to the hidden layer
            self.z2_delta = self.z2_error * self.sigmoidPrime(self.z2)
            # manual in-place weight updates (no separate learning rate)
            self.W1 += torch.matmul(torch.t(X), self.z2_delta)
            self.W2 += torch.matmul(torch.t(self.z2), self.o_delta)
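    The sigmoidPrime method expects a value that has already passed through sigmoid, because sigma'(x) = sigma(x) * (1 - sigma(x)). The sketch below uses standalone functions that mirror the class methods, checks that identity against PyTorch's autograd at a few arbitrary points, and confirms the 3 X 3 hidden shape for a hypothetical 3 X 2 input batch.

```python
import torch


def sigmoid(s):
    return 1 / (1 + torch.exp(-s))


def sigmoidPrime(s):
    # s must already be a sigmoid output: sigma'(x) = sigma(x) * (1 - sigma(x))
    return s * (1 - s)


# Compare the closed-form derivative with autograd at arbitrary sample points
x = torch.linspace(-4.0, 4.0, steps=9, requires_grad=True)
out = sigmoid(x)
(autograd_grad,) = torch.autograd.grad(out.sum(), x)
manual_grad = sigmoidPrime(out.detach())
print(torch.allclose(manual_grad, autograd_grad, atol=1e-6))  # True

# Shape check for the forward pass with a hypothetical 3 X 2 input batch
X = torch.rand(3, 2)
W1 = torch.randn(2, 3)
W2 = torch.randn(3, 1)
z = torch.matmul(X, W1)
o = sigmoid(torch.matmul(sigmoid(z), W2))
print(z.shape, o.shape)  # torch.Size([3, 3]) torch.Size([3, 1])
```

    Passing the pre-activation x to sigmoidPrime instead would compute x * (1 - x), which is not the derivative; that is why backward calls it on o and self.z2 rather than on self.z3 and self.z.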
    
  • Step 3

    Create the training and prediction methods as described below:
    
        def train(self, X, y):
            # forward + backward pass for training
            # note: this overrides nn.Module.train(), so calling eval() on this model would fail
            o = self.forward(X)
            self.backward(X, y, o)

        def saveWeights(self, model):
            # use PyTorch's built-in serialization to save the whole model object
            torch.save(model, "NN")
            # you can reload the model, with all its weights, using:
            # torch.load("NN", weights_only=False)

        def predict(self, xPredicted):
            print("Predicted data based on trained weights: ")
            print("Input (scaled): \n" + str(xPredicted))
            print("Output: \n" + str(self.forward(xPredicted)))
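    Putting the three steps together, a minimal end-to-end sketch might look like the following. It assumes made-up toy data (three rows, inputs scaled per column to [0, 1], targets already in (0, 1)), renames train to a hypothetical train_step so it does not shadow nn.Module.train(), condenses forward to the values backward needs, and saves only the two weight tensors instead of the whole model object.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Neural_Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.inputSize, self.outputSize, self.hiddenSize = 2, 1, 3
        self.W1 = torch.randn(self.inputSize, self.hiddenSize)
        self.W2 = torch.randn(self.hiddenSize, self.outputSize)

    def forward(self, X):
        # keep the hidden activation; backward needs it
        self.z2 = self.sigmoid(torch.matmul(X, self.W1))
        return self.sigmoid(torch.matmul(self.z2, self.W2))

    def sigmoid(self, s):
        return 1 / (1 + torch.exp(-s))

    def sigmoidPrime(self, s):
        return s * (1 - s)

    def backward(self, X, y, o):
        o_delta = (y - o) * self.sigmoidPrime(o)
        z2_delta = torch.matmul(o_delta, torch.t(self.W2)) * self.sigmoidPrime(self.z2)
        self.W1 += torch.matmul(torch.t(X), z2_delta)
        self.W2 += torch.matmul(torch.t(self.z2), o_delta)

    def train_step(self, X, y):
        # hypothetical rename of train() so nn.Module.train()/eval() keep working
        o = self.forward(X)
        self.backward(X, y, o)


def mse(net, X, y):
    return torch.mean((y - net(X)) ** 2).item()


torch.manual_seed(0)
# Made-up toy data (not from the tutorial)
X = torch.tensor([[2.0, 9.0], [1.0, 5.0], [3.0, 6.0]])
X = X / X.max(dim=0).values  # scale each column to [0, 1]
y = torch.tensor([[0.92], [0.86], [0.89]])

NN = Neural_Network()
initial_loss = mse(NN, X, y)
for _ in range(1000):
    NN.train_step(X, y)
final_loss = mse(NN, X, y)
print(f"loss before: {initial_loss:.4f}, after: {final_loss:.4f}")

# Save and reload only the weight tensors (a plain dict of tensors)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "nn_weights.pt")
    torch.save({"W1": NN.W1, "W2": NN.W2}, path)
    restored = torch.load(path)

# Predict for one hypothetical, already-scaled input row
xPredicted = torch.tensor([[0.5, 1.0]])
prediction = NN(xPredicted)
print(prediction)
```

    Saving a dict of tensors keeps the file independent of the class definition, whereas the tutorial's torch.save(model, "NN") pickles the whole object, so the Neural_Network class must be importable when it is loaded again.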