Introductory LSTM: Sine-Wave Sequence Prediction
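Code description:
An introductory LSTM exercise that predicts a sine-wave sequence. The forward pass feeds the sequence through two stacked LSTMCell layers and a linear output layer, one time step at a time. If future > 0, it keeps predicting by feeding each output back in as the next input. The hidden size and the __init__ wrapper below are assumptions needed to make the original forward loop runnable.

```python
import torch
import torch.nn as nn

class Sequence(nn.Module):
    # Two stacked LSTMCells plus a linear head; hidden_size=51 is an assumed default.
    def __init__(self, hidden_size=51):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm1 = nn.LSTMCell(1, hidden_size)
        self.lstm2 = nn.LSTMCell(hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, input, future=0):
        # input.shape: [batch, n]; hidden/cell states start at zero, matching input's dtype/device
        batch = input.size(0)
        h_t, c_t, h_t2, c_t2 = (input.new_zeros(batch, self.hidden_size) for _ in range(4))
        outputs = []  # will hold n (+ future) tensors of shape [batch, 1]
        for input_t in input.chunk(input.size(1), dim=1):  # input_t.shape: [batch, 1]
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)  # output.shape: [batch, 1]
            outputs += [output]  # collect every step's prediction
        for _ in range(future):  # optionally predict `future` extra steps
            # the previous prediction becomes the next input
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)  # output.shape: [batch, 1]
            outputs += [output]
        # after stack: [batch, n + future, 1]; after squeeze(2): [batch, n + future]
        return torch.stack(outputs, 1).squeeze(2)
```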
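The description does not show how the sine-wave data is built. Below is a minimal sketch that assumes a common setup: N sine sequences of length L with random phase shifts, and next-step prediction, where the target is the input shifted by one step. The values N, L, T and the shift range are hypothetical. The sketch also shows that chunking along dim=1 yields one [batch, 1] tensor per time step, which is what the forward loop iterates over.

```python
import numpy as np
import torch

# Hypothetical settings: N sequences of length L; T scales the sine period.
N, L, T = 100, 1000, 20
rng = np.random.default_rng(0)

# Each row is sin((t + shift) / T) with a random integer phase shift per row.
t = np.arange(L) + rng.integers(-4 * T, 4 * T, size=(N, 1))  # broadcasts to [N, L]
data = torch.from_numpy(np.sin(t / T))  # float64 tensor, shape [N, L]

# Next-step prediction: the target is the input shifted left by one step.
inputs = data[:, :-1]   # [N, L-1]
targets = data[:, 1:]   # [N, L-1]

# The forward loop walks the sequence one column at a time:
steps = inputs.chunk(inputs.size(1), dim=1)
print(len(steps), steps[0].shape)  # 999 torch.Size([100, 1])
```

With this layout, the model's output for `inputs` has shape [N, L-1], the same as `targets`, so the two can be compared directly with a loss such as mean squared error. Passing future=k extends the output to [N, L-1+k].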