线性回归是最容易理解的一种预测方式。线性方程Y = aX+b,大家都认识。回归预测就是知道一堆X和Y的值,计算出最接近真实a和b的两个值。这个一般运用在连续的,线性的数据预测上。
这里我将举一个线性回归的例子。线性就必须是连续的数据,Y=aX+b,是最好理解的线性方程,我们随机给出一组连续的X值(100个),然后把这些数据代入Y=aX+b方程里求出100个Y,a的值我取2.5,b的值我取随机数。这样,我就得到一组经过“抖动”后的数据。并用这两组数据X和Y去训练,求出最接近的a`和b
`。
然后我再用求出的a`和b`和原始数据X套入公式,求出Y,然后画出Y和Y
`。比较两组数据差异。
>>>示例程序
import numpy as np import torch import matplotlib.pyplot as plt def prepare_db(): #生成100个点的X值 train_X = np.linspace(-2*np.pi,2*np.pi,100) #根据train_db的值,生成相应的Y=aX+b值,并进行随机加减 a=2.5 train_Y = train_X*a+np.random.rand(100) train_X = train_X.reshape(-1,1) train_Y = train_Y.reshape(-1,1) return train_X,train_Y #============================================= class LinearRegression(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(1,1) def forward(self,x): out = self.linear(x) return out #============================================== class Linear_Model(): def __init__(self): self.learning_rate = 0.001 self.epoches = 10000 self.loss_function = torch.nn.MSELoss() self.create_model() def create_model(self): self.model = LinearRegression() self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate) def train(self, train_X,train_Y, model_save_path="model.pth"): x = torch.tensor(train_X).float() y = torch.tensor(train_Y).float() for epoch in range(self.epoches): prediction = self.model(x) loss = self.loss_function(prediction, y) self.optimizer.zero_grad() loss.backward() self.optimizer.step() if epoch % 1000 == 0: print("epoch: {}, loss is: {}".format(epoch, loss.item())) torch.save(self.model.state_dict(), "linear.pth") def test(self,test_db,model_open_path='model.pth'): self.model.load_state_dict(torch.load(model_open_path)) prediction = self.model(torch.tensor(test_db).float()) return prediction.detach().numpy() #================================================== if __name__ == '__main__': train_X,train_Y = prepare_db() linear = Linear_Model() linear.train(train_X,train_Y) ret = linear.test(train_X,"linear.pth") plt.plot(train_Y,'g') plt.plot(ret,'r') plt.show()
>>>运行结果

Pytorch与深度学习-02.回归预测