FizzBuzzをするのにTensorFlowというネタツイートがあったので、「機械学習 FizzBuzz」で検索したらすでにいくつかあったので、真似してみたくなった。ちなみに機械学習のことは全然わかってません。
IT会社面接官:「数字を列挙し、3の倍数ならfizz、5の倍数ならbuzz、15の倍数ならfizzbuzzを出力するプログラムを書いてください。」
— Graham Neubig (@neubig) 2016年5月24日
面接を受けている人:「では、まずTensorFlowをインポートします…」https://t.co/UNk2jH5rfn
ライブラリー/フレームワーク
EncogというJavaの機械学習ライブラリーを使いました(TensorFlow使いたかった…)。
EncogはJeff Heatonという方が開発している機械学習ライブラリーで、サポートベクターマシン、ニューラルネットワーク、ベイジアンネットワークなどいくつかのアルゴリズムをサポートしているJava/C#用フレームワークだそうです。ドキュメントもしっかりしているので、初心者でも使いやすいライブラリーとなっています。
導入
dependencies {
compile 'org.encog:encog-core:3.3.0'
}
学習データ
201〜1000を学習データとして用いました。
特徴量は「3で割り切れた場合1.0、割り切れなかった場合0.0」「5で割り切れた場合1.0、割り切れなかった場合0.0」という二つの量を設定しました。
//入力の数値 data class InputNumber( val num: Int, val fizz: Boolean, val buzz: Boolean) //Intから入力の数値に変換する val toInput: (Int) -> InputNumber = { val fizz = it % 3 == 0 val buzz = it % 5 == 0 InputNumber(it, fizz, buzz) }
Qiitaの記事のSVMでFizzBuzzと同じように、FizzBuzzのフィルター、Buzzのフィルター、Fizzのフィルターを用意し、入力された数値がそれぞれのフィルターに適合する場合は1.0、適合しない場合0.0となるような教師データを用意します。たとえばFizz用のフィルターには次のような学習データを与えてトレーニングします。
数値 | 特徴量 | 出力 |
---|---|---|
201 | [1.0, 0.0] |
1.0 |
202 | [0.0, 0.0] |
0.0 |
203 | [0.0, 0.0] |
0.0 |
204 | [1.0, 0.0] |
1.0 |
205 | [0.0, 1.0] |
0.0 |
//Fizz/Buzzをあらわす enum class FizzBuzz { FIZZ, BUZZ, FIZZ_BUZZ, NONE } //入力値のFizzBuzz結果 fun InputNumber.toFizzBuzz(): FizzBuzz { return when(Pair(fizz, buzz)) { Pair(true, true) -> FizzBuzz.FIZZ_BUZZ Pair(false, true) -> FizzBuzz.BUZZ Pair(true, false) -> FizzBuzz.FIZZ else -> FizzBuzz.NONE } } //FizzBuzzの教師データ変換 fun FizzBuzz.ideal(fizzBuzz: FizzBuzz): DoubleArray = if (this == fizzBuzz) doubleArrayOf(1.0) else doubleArrayOf(0.0) //学習データ data class FizzBuzzStudy( val input: InputNumber, val fizzBuzz: FizzBuzz) //学習データの入力値を取り出す fun FizzBuzzStudy.inputArray(): DoubleArray = booleanArrayOf(input.fizz, input.buzz).toDoubleArray() //学習データの教師データを取り出す fun FizzBuzzStudy.idealDataArray(): DoubleArray = input.toFizzBuzz().ideal(fizzBuzz)
学習の実行
今回はBasicNetworkというネットワークを使います。これはEncogのQuick Start GuideにあるXORの学習で使われているものです。
学習エラーが0.01以下になるまで繰り返し学習を行います。
//dataは201〜1000までの学習データを保持するオブジェクト //fizzBuzzはフィルターをあらわすFizzBuzz(FizzBuzz専用フィルターかFizz専用フィルターか) class Filter(val data: MLDataSet, val fizzBuzz: FizzBuzz) { val network = BasicNetwork() init { network.addLayer(BasicLayer(null, true, 2)) network.addLayer(BasicLayer(ActivationSigmoid(), true, 4)) network.addLayer(BasicLayer(ActivationSigmoid(), false, 1)) network.structure.finalizeStructure() network.reset() val training = ResilientPropagation(network, data) for (t in 1..30) { training.iteration() if (training.error < 0.01) break } training.finishTraining() } //etc... }
フィルター
学習が終わると、Filter
のnetwork
というオブジェクトはcompute
というメソッドで、入力値(特徴量)が適合する場合に1に近い数値を返すようになります。
これをFizzBuzzの場合、Buzzの場合、Fizzの場合の3回おこなって、FizzBuzzの判断をおこないます。
class Filter(val data: MLDataSet, val fizzBuzz: FizzBuzz) { //一部省略 //入力値から値を予測する private fun compute(input: MLData): MLData = network.compute(input) //入力値が適合するかチェックする fun match(input: InputNumber): Boolean = compute(BasicMLData(input.toDataArray())).getData(0).round() == 1 } //数値をFizzBuzzに変換する fun InputNumber.fizzBuzz(filters: List<Filter>): String { val found = filters.find { it.match(this) } return when(found?.fizzBuzz) { FizzBuzz.FIZZ_BUZZ -> FizzBuzz.FIZZ_BUZZ.name FizzBuzz.BUZZ -> FizzBuzz.BUZZ.name FizzBuzz.FIZZ -> FizzBuzz.FIZZ.name else -> "$num" } }
プログラム実行
学習データは201〜1000、テストするデータは1〜40として機械学習によるFizzBuzzを実行してみます。
object StudyRange { val STUDY_START = 201 val STUDY_END = 1000 } fun main(args: Array<String>) { val inputList: List<InputNumber> = StudyRange.STUDY_START.rangeTo(StudyRange.STUDY_END).map(toInput) val studyDataList: List<StudyData> = fizzBuzzList().map { StudyData(inputList, it) } val filters: List<Filter> = studyDataList.map { Filter(it.dataSet(), it.fizzBuzz) } val testData: List<InputNumber> = (1..40).map(toInput) testData.map { "${it.num} -> ${it.fizzBuzz(filters)} (correct: ${it.correct()})" }.forEach{ println(it) } Encog.getInstance().shutdown() }
実行結果
1 -> 1 (correct: 1) 2 -> 2 (correct: 2) 3 -> FIZZ (correct: FIZZ) 4 -> 4 (correct: 4) 5 -> BUZZ (correct: BUZZ) 6 -> FIZZ (correct: FIZZ) 7 -> 7 (correct: 7) 8 -> 8 (correct: 8) 9 -> FIZZ (correct: FIZZ) 10 -> BUZZ (correct: BUZZ) 11 -> 11 (correct: 11) 12 -> FIZZ (correct: FIZZ) 13 -> 13 (correct: 13) 14 -> 14 (correct: 14) 15 -> FIZZ_BUZZ (correct: FIZZ_BUZZ) 16 -> 16 (correct: 16) 17 -> 17 (correct: 17) 18 -> FIZZ (correct: FIZZ) 19 -> 19 (correct: 19) 20 -> BUZZ (correct: BUZZ) 21 -> FIZZ (correct: FIZZ) 22 -> 22 (correct: 22) 23 -> 23 (correct: 23) 24 -> FIZZ (correct: FIZZ) 25 -> BUZZ (correct: BUZZ) 26 -> 26 (correct: 26) 27 -> FIZZ (correct: FIZZ) 28 -> 28 (correct: 28) 29 -> 29 (correct: 29) 30 -> FIZZ_BUZZ (correct: FIZZ_BUZZ) 31 -> 31 (correct: 31) 32 -> 32 (correct: 32) 33 -> FIZZ (correct: FIZZ) 34 -> 34 (correct: 34) 35 -> BUZZ (correct: BUZZ) 36 -> FIZZ (correct: FIZZ) 37 -> 37 (correct: 37) 38 -> 38 (correct: 38) 39 -> FIZZ (correct: FIZZ) 40 -> BUZZ (correct: BUZZ)
どうやらうまくいったようです。
全体のコードはgistにあります。