mike-neckのブログ

Java or Groovy or Swift or Golang

Javaの機械学習ライブラリーでFizzBuzzしてみた

FizzBuzzをするのにTensorFlowというネタツイートがあったので、「機械学習 FizzBuzz」で検索したらすでにいくつかあったので、真似してみたくなった。ちなみに機械学習のことは全然わかってません。

qiita.com

blog.amedama.jp

github.com


ライブラリー/フレームワーク

EncogというJava機械学習ライブラリーを使いました(TensorFlow使いたかった…)。

EncogはJeff Heatonという方が開発している機械学習ライブラリーで、サポートベクターマシンニューラルネットワークベイジアンネットワークなどいくつかのアルゴリズムをサポートしているJava/C#フレームワークだそうです。ドキュメントもしっかりしているので、初心者でも使いやすいライブラリーとなっています。

github.com


導入

dependencies {
  compile 'org.encog:encog-core:3.3.0'
}

学習データ

201〜1000を学習データとして用いました。

特徴量は「3で割り切れた場合1.0、割り切れなかった場合0.0」「5で割り切れた場合1.0、割り切れなかった場合0.0」という二つの量を設定しました。

//入力の数値
data class InputNumber(
        val num: Int, val fizz: Boolean, val buzz: Boolean)

//Intから入力の数値に変換する
val toInput: (Int) -> InputNumber = {
    val fizz = it % 3 == 0
    val buzz = it % 5 == 0
    InputNumber(it, fizz, buzz)
}

Qiitaの記事のSVMFizzBuzzと同じように、FizzBuzzのフィルター、Buzzのフィルター、Fizzのフィルターを用意し、入力された数値がそれぞれのフィルターに適合する場合は1.0、適合しない場合0.0となるような教師データを用意します。たとえばFizz用のフィルターには次のような学習データを与えてトレーニングします。

数値 特徴量 出力
201 [1.0, 0.0] 1.0
202 [0.0, 0.0] 0.0
203 [0.0, 0.0] 0.0
204 [1.0, 0.0] 1.0
205 [0.0, 1.0] 0.0
//Fizz/Buzzをあらわす
enum class FizzBuzz {
    FIZZ,
    BUZZ,
    FIZZ_BUZZ,
    NONE
}

//入力値のFizzBuzz結果
fun InputNumber.toFizzBuzz(): FizzBuzz {
    return when(Pair(fizz, buzz)) {
        Pair(true, true)  -> FizzBuzz.FIZZ_BUZZ
        Pair(false, true) -> FizzBuzz.BUZZ
        Pair(true, false) -> FizzBuzz.FIZZ
        else              -> FizzBuzz.NONE
    }
}

//FizzBuzzの教師データ変換
fun FizzBuzz.ideal(fizzBuzz: FizzBuzz): DoubleArray =
        if (this == fizzBuzz) doubleArrayOf(1.0)
        else doubleArrayOf(0.0)

//学習データ
data class FizzBuzzStudy(
        val input: InputNumber,
        val fizzBuzz: FizzBuzz)

//学習データの入力値を取り出す
fun FizzBuzzStudy.inputArray(): DoubleArray =
        booleanArrayOf(input.fizz, input.buzz).toDoubleArray()

//学習データの教師データを取り出す
fun FizzBuzzStudy.idealDataArray(): DoubleArray =
        input.toFizzBuzz().ideal(fizzBuzz)

学習の実行

今回はBasicNetworkというネットワークを使います。これはEncogのQuick Start GuideにあるXORの学習で使われているものです。

学習エラーが0.01以下になるまで繰り返し学習を行います。

//dataは201〜1000までの学習データを保持するオブジェクト
//fizzBuzzはフィルターをあらわすFizzBuzz(FizzBuzz専用フィルターかFizz専用フィルターか)
class Filter(val data: MLDataSet, val fizzBuzz: FizzBuzz) {

    val network = BasicNetwork()

    init {
        network.addLayer(BasicLayer(null, true, 2))
        network.addLayer(BasicLayer(ActivationSigmoid(), true, 4))
        network.addLayer(BasicLayer(ActivationSigmoid(), false, 1))
        network.structure.finalizeStructure()
        network.reset()

        val training = ResilientPropagation(network, data)
        for (t in 1..30) {
            training.iteration()
            if (training.error < 0.01) break
        }
        training.finishTraining()
    }

    //etc...
}

フィルター

学習が終わると、Filternetworkというオブジェクトはcomputeというメソッドで、入力値(特徴量)が適合する場合に1に近い数値を返すようになります。

これをFizzBuzzの場合、Buzzの場合、Fizzの場合の3回おこなって、FizzBuzzの判断をおこないます。

class Filter(val data: MLDataSet, val fizzBuzz: FizzBuzz) {
    //一部省略

    //入力値から値を予測する
    private fun compute(input: MLData): MLData = network.compute(input)

    //入力値が適合するかチェックする
    fun match(input: InputNumber): Boolean =
        compute(BasicMLData(input.toDataArray())).getData(0).round() == 1
}

//数値をFizzBuzzに変換する
fun InputNumber.fizzBuzz(filters: List<Filter>): String {
    val found = filters.find { it.match(this) }
    return when(found?.fizzBuzz) {
        FizzBuzz.FIZZ_BUZZ -> FizzBuzz.FIZZ_BUZZ.name
        FizzBuzz.BUZZ      -> FizzBuzz.BUZZ.name
        FizzBuzz.FIZZ      -> FizzBuzz.FIZZ.name
        else               -> "$num"
    }
}

プログラム実行

学習データは201〜1000、テストするデータは1〜40として機械学習によるFizzBuzzを実行してみます。

object StudyRange {
    val STUDY_START = 201
    val STUDY_END = 1000
}

fun main(args: Array<String>) {
    val inputList: List<InputNumber> =
            StudyRange.STUDY_START.rangeTo(StudyRange.STUDY_END).map(toInput)
    val studyDataList: List<StudyData> = fizzBuzzList().map { StudyData(inputList, it) }
    val filters: List<Filter> = studyDataList.map { Filter(it.dataSet(), it.fizzBuzz) }

    val testData: List<InputNumber> = (1..40).map(toInput)
    testData.map {
        "${it.num} -> ${it.fizzBuzz(filters)} (correct: ${it.correct()})"
    }.forEach{
        println(it)
    }
    Encog.getInstance().shutdown()
}

実行結果

1 -> 1 (correct: 1)
2 -> 2 (correct: 2)
3 -> FIZZ (correct: FIZZ)
4 -> 4 (correct: 4)
5 -> BUZZ (correct: BUZZ)
6 -> FIZZ (correct: FIZZ)
7 -> 7 (correct: 7)
8 -> 8 (correct: 8)
9 -> FIZZ (correct: FIZZ)
10 -> BUZZ (correct: BUZZ)
11 -> 11 (correct: 11)
12 -> FIZZ (correct: FIZZ)
13 -> 13 (correct: 13)
14 -> 14 (correct: 14)
15 -> FIZZ_BUZZ (correct: FIZZ_BUZZ)
16 -> 16 (correct: 16)
17 -> 17 (correct: 17)
18 -> FIZZ (correct: FIZZ)
19 -> 19 (correct: 19)
20 -> BUZZ (correct: BUZZ)
21 -> FIZZ (correct: FIZZ)
22 -> 22 (correct: 22)
23 -> 23 (correct: 23)
24 -> FIZZ (correct: FIZZ)
25 -> BUZZ (correct: BUZZ)
26 -> 26 (correct: 26)
27 -> FIZZ (correct: FIZZ)
28 -> 28 (correct: 28)
29 -> 29 (correct: 29)
30 -> FIZZ_BUZZ (correct: FIZZ_BUZZ)
31 -> 31 (correct: 31)
32 -> 32 (correct: 32)
33 -> FIZZ (correct: FIZZ)
34 -> 34 (correct: 34)
35 -> BUZZ (correct: BUZZ)
36 -> FIZZ (correct: FIZZ)
37 -> 37 (correct: 37)
38 -> 38 (correct: 38)
39 -> FIZZ (correct: FIZZ)
40 -> BUZZ (correct: BUZZ)

どうやらうまくいったようです。

全体のコードはgistにあります。

機械学習でFizzBuzz · GitHub