这篇文章将介绍一种直截了当的方法,可以估计与最先进的手动方法接近的参数。
我们将使用贝叶斯优化方法(Mango)在短短200次迭代中从108,000个可能选项中搜索最佳参数。
ARIMA时间序列预测模型非常适合具有趋势和季节性的序列。这是一个被广泛采用的经典模型,通常作为基准现代深度学习方法的基线。然而,估计其准确参数具有挑战性。研究人员和开发人员通常使用包括视觉绘图在内的试错方法。
ARIMA模型是什么?
ARIMA模型是“自动递归移动平均线”的缩写,是一类使用过去值来估计未来预测的模型。ARIMA模型由三个参数定义:p、d和q。
ARIMA模型在文献中研究了不同的变体。在这篇文章中,我们将使用statsmodels库中的实现。
整个笔记本显示一个简单的实现在这里可用。您可以为您的数据集修改此实现。根据需要创建单独的火车测试拆分。我简化了概述重要的调音步骤。
完整代码:使用芒果自动调音
import pandas as pd df = pd.read_csv('https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv') from statsmodels.tsa.arima.model import ARIMA from sklearn.metrics import mean_squared_error from mango import scheduler, Tuner def arima_objective_function(args_list): global data_values params_evaluated = [] results = [] for params in args_list: try: p,d,q = params['p'],params['d'], params['q'] trend = params['trend'] model = ARIMA(data_values, order=(p,d,q), trend = trend) predictions = model.fit() mse = mean_squared_error(data_values, predictions.fittedvalues) params_evaluated.append(params) results.append(mse) except: #print(f"Exception raised for {params}") #pass params_evaluated.append(params) results.append(1e5) #print(params_evaluated, mse) return params_evaluated, results param_space = dict(p= range(0, 30), d= range(0, 30), q =range(0, 30), trend = ['n', 'c', 't', 'ct'] ) conf_Dict = dict() conf_Dict['num_iteration'] = 200 data_values = list(df['#Passengers']) tuner = Tuner(param_space, arima_objective_function, conf_Dict) results = tuner.minimize() print('best parameters:', results['best_params']) print('best loss:', results['best_objective']) best parameters: {'d': 0, 'p': 17, 'q': 23, 'trend': 'ct'} best loss: 112.06886739549542
调音步骤
数据集:我们将使用一个简单的空中乘客数据集,记录航空公司乘客人数。
import pandas as pd df = pd.read_csv('https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv') df.head()
绘制系列图,以了解趋势和季节性
from matplotlib import pyplot as plt f = plt.figure() f.set_figwidth(15) f.set_figheight(6) plt.plot(df['#Passengers'], linewidth = 4, label = "original Series") plt.legend(fontsize=25) plt.xlabel('Months', fontsize = 25) plt.ylabel('Count', fontsize = 25) plt.show()
该数据集呈上升趋势,季节性为12个月。
传统上,一种方法可以使用领域知识从原始序列中去除趋势和季节性,然后使用剩余序列来预测未来。然而,我们将研究一种更直接的自动化方法。
如何自动调整参数?
我们将使用一个名为Mango的最先进的优化库来为我们的数据集找到最佳参数。让我们首先定义参数的范围。在这种优化方法中,我们定义了可能的参数范围。这个范围可能非常大,不需要精确。这些参数是从statsmodels库中定义的。
param_space = dict(p= range(0, 30), d= range(0, 30), q =range(0, 30), trend = ['n', 'c', 't', 'ct'] )
参数空间是使用python构造定义的:范围和列表。参数总可能组合的集合是30*30*30*4 = 108,000。因此,详尽的网格搜索非常耗时。我们将使用贝叶斯搜索优化器方法,在大约100次迭代内自动进行搜索。注意:根据您的数据集,范围的大小及其搜索空间可能会有所不同。定义一个大的搜索空间很好;让优化器为你做艰苦的工作。
训练ARIMA模型
要使用Mango,我们必须定义一个目标函数,该函数返回给定参数集的ARIMA模型错误。
from statsmodels.tsa.arima.model import ARIMA from sklearn.metrics import mean_squared_error from mango import scheduler, Tuner def arima_objective_function(args_list): global data_values params_evaluated = [] results = [] for params in args_list: try: p,d,q = params['p'],params['d'], params['q'] trend = params['trend'] model = ARIMA(data_values, order=(p,d,q), trend = trend) predictions = model.fit() mse = mean_squared_error(data_values, predictions.fittedvalues) params_evaluated.append(params) results.append(mse) except: #print(f"Exception raised for {params}") #pass params_evaluated.append(params) results.append(1e5) #print(params_evaluated, mse) return params_evaluated, results
我们从Mango库中获取参数,并返回参数及其结果。结果包括经过训练的ARIMA模型的错误。在这种情况下,错误是mean_squared_error。我们还包括try-catch语句,因为ARIMA模型可能不会对参数的每个组合/选择收敛。我们只返回模型工作的参数集。芒果内部优化使用这些参数,在很少的迭代中找到最佳模型(在本例中为100)。我们的目标是找到最小化错误函数的参数。
控制芒果迭代:配置参数。
来自芒果进口调度器,调谐器
from mango import scheduler, Tuner conf_Dict = dict() conf_Dict['num_iteration'] = 200 tuner = Tuner(param_space, arima_objective_function, conf_Dict)
可视化最佳模型预测
总的来说,我们看到总的可能参数组合非常大(108,000)。
def plot_arima(data_values, order = (1,1,1), trend = 'c'): print('final model:', order, trend) model = ARIMA(data_values, order=order, trend = trend) results = model.fit() error = mean_squared_error(data_values, results.fittedvalues) print('MSE error is:', error) from matplotlib import pyplot as plt f = plt.figure() f.set_figwidth(15) f.set_figheight(6) plt.plot(data_values, label = "original Series", linewidth = 4) plt.plot(results.fittedvalues, color='red', label = "Predictions", linestyle='dashed', linewidth = 3) plt.legend(fontsize = 25) plt.xlabel('Months', fontsize = 25) plt.ylabel('Count', fontsize = 25) plt.show() print(results['best_params']) order = (results['best_params']['p'], results['best_params']['d'], results['best_params']['q']) plot_arima(data_values, order=order, trend = results['best_params']['trend'])
标签:arima,statsmodels,python 来源:
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。