rabbit-dev

現役スマフォプログラマーが適当にプログラム関係の記事を放り込むブログ

機械学習を使って株価予測をしてみた

tensorflowやchainerなど最近何かと話題の人工知能機械学習。 今回はscikit-learnを使って株価予測をしてみた。以下の条件で予測してみると53%の確率になった。これは微妙。。。

条件

  • 学習期間は2016年の1年間
  • 予測期間は2017年の一部
  • 上がる下がるのみ判断する
  • 過去5日分の終値のデータを利用する

結果

予測結果
正解  : 47回
不正解 : 42回

コード

気になってる方の参考になればと思いソースコードを公開します。適当につかってください。問題起きても責任はとれませんのであしからず。

処理フロー
1. 学習用データから過去5日分のデータと、翌日株価が上がったか下がったかのデータを作成する。 2. scikit-learnの決定木を利用して1のデータを学習させる。
3. 予測用データから過去5日分のデータを作成し、決定木にデータを入力して予測を行う。
4. 予測が正しかったか判定する。

https://github.com/yuzoh/SampleStockPredict

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import csv
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier

res_ok = 0
res_ng = 0

#
# 学習データ作成
#
train_data = []
train_result = []

f = open('train.csv')
reader = csv.reader(f)
tmp_data = []
prev_data = 0.0
cur_data = 0.0

for item in reader:
    cur_data = float(item[0])

    # 過去5日分のデータを元に学習データを作成する
    if 5 <= len(tmp_data):
        # 過去5日分のデータ
        train_data.append(tmp_data[:])

        # 翌日上がったか、下がったかのデータ
        if prev_data < cur_data:
            train_result.append(1)
        else:
            train_result.append(0)

        tmp_data.pop(0)

    # データの更新
    tmp_data.append(cur_data)
    prev_data = cur_data

#
# 学習
#  
clf = tree.DecisionTreeClassifier()
#clf = RandomForestClassifier()
clf.fit(train_data, train_result)

#
# 予測
#
f = open('predict.csv')
reader = csv.reader(f)
tmp_data = []
prev_data = 0.0
cur_data = 0.0

for item in reader:
    cur_data = float(item[0])

    # 過去5日分のデータを元に翌日の予測を行う
    if 5 <= len(tmp_data):
        # 予測するためのデータ作成
        predict_data = np.array(tmp_data)
        predict_data = predict_data.reshape(1, -1)

        # 予測
        result = clf.predict(predict_data)

        # 実際に上がったのか下がったのか判定
        if prev_data < cur_data:
            res = 1
        else:
            res = 0

        # 予測の比較
        if result == res:
            res_ok = res_ok + 1
        else:
            res_ng = res_ng + 1

        tmp_data.pop(0)

    tmp_data.append(cur_data)
    prev_data = cur_data

print "予測結果"
print "正解  : %d回" % res_ok
print "不正解 : %d回" % res_ng

もっと予測結果良くしたい!

  • 翌日を予測するのではなく1ヶ月後などを予測してみる
  • 前日との変動率を学習させる
  • 日経平均以外のデータも学習させる

感想

tensorflowを使いたかったんだけど、シグモイド関数?とかがわからなかったので簡単そうなscikit-learn使いました。 かなりお手軽に機械学習できるんで、興味持ってる方は是非やってみてください。