FizzBuzzをするのにTensorFlowというネタツイートがあったので、「機械学習 FizzBuzz」で検索したらすでにいくつかあったので、真似してみたくなった。ちなみに機械学習のことは全然わかってません。
qiita.com
blog.amedama.jp
github.com
EncogというJavaの機械学習ライブラリーを使いました(TensorFlow使いたかった…)。
EncogはJeff Heatonという方が開発している機械学習ライブラリーで、サポートベクターマシン、ニューラルネットワーク、ベイジアンネットワークなどいくつかのアルゴリズムをサポートしているJava/C#用フレームワークだそうです。ドキュメントもしっかりしているので、初心者でも使いやすいライブラリーとなっています。
github.com
導入
dependencies {
compile 'org.encog:encog-core:3.3.0'
}
学習データ
201〜1000を学習データとして用いました。
特徴量は「3で割り切れた場合1.0、割り切れなかった場合0.0」「5で割り切れた場合1.0、割り切れなかった場合0.0」という二つの量を設定しました。
data class InputNumber(
val num: Int, val fizz: Boolean, val buzz: Boolean)
val toInput: (Int) -> InputNumber = {
val fizz = it % 3 == 0
val buzz = it % 5 == 0
InputNumber(it, fizz, buzz)
}
Qiitaの記事のSVMでFizzBuzzと同じように、FizzBuzzのフィルター、Buzzのフィルター、Fizzのフィルターを用意し、入力された数値がそれぞれのフィルターに適合する場合は1.0、適合しない場合0.0となるような教師データを用意します。たとえばFizz用のフィルターには次のような学習データを与えてトレーニングします。
数値 |
特徴量 |
出力 |
201 |
[1.0, 0.0] |
1.0 |
202 |
[0.0, 0.0] |
0.0 |
203 |
[0.0, 0.0] |
0.0 |
204 |
[1.0, 0.0] |
1.0 |
205 |
[0.0, 1.0] |
0.0 |
enum class FizzBuzz {
FIZZ,
BUZZ,
FIZZ_BUZZ,
NONE
}
fun InputNumber.toFizzBuzz(): FizzBuzz {
return when(Pair(fizz, buzz)) {
Pair(true, true) -> FizzBuzz.FIZZ_BUZZ
Pair(false, true) -> FizzBuzz.BUZZ
Pair(true, false) -> FizzBuzz.FIZZ
else -> FizzBuzz.NONE
}
}
fun FizzBuzz.ideal(fizzBuzz: FizzBuzz): DoubleArray =
if (this == fizzBuzz) doubleArrayOf(1.0)
else doubleArrayOf(0.0)
data class FizzBuzzStudy(
val input: InputNumber,
val fizzBuzz: FizzBuzz)
fun FizzBuzzStudy.inputArray(): DoubleArray =
booleanArrayOf(input.fizz, input.buzz).toDoubleArray()
fun FizzBuzzStudy.idealDataArray(): DoubleArray =
input.toFizzBuzz().ideal(fizzBuzz)
学習の実行
今回はBasicNetworkというネットワークを使います。これはEncogのQuick Start GuideにあるXORの学習で使われているものです。
学習エラーが0.01以下になるまで繰り返し学習を行います。
class Filter(val data: MLDataSet, val fizzBuzz: FizzBuzz) {
val network = BasicNetwork()
init {
network.addLayer(BasicLayer(null, true, 2))
network.addLayer(BasicLayer(ActivationSigmoid(), true, 4))
network.addLayer(BasicLayer(ActivationSigmoid(), false, 1))
network.structure.finalizeStructure()
network.reset()
val training = ResilientPropagation(network, data)
for (t in 1..30) {
training.iteration()
if (training.error < 0.01) break
}
training.finishTraining()
}
}
フィルター
学習が終わると、Filter
のnetwork
というオブジェクトはcompute
というメソッドで、入力値(特徴量)が適合する場合に1に近い数値を返すようになります。
これをFizzBuzzの場合、Buzzの場合、Fizzの場合の3回おこなって、FizzBuzzの判断をおこないます。
class Filter(val data: MLDataSet, val fizzBuzz: FizzBuzz) {
private fun compute(input: MLData): MLData = network.compute(input)
fun match(input: InputNumber): Boolean =
compute(BasicMLData(input.toDataArray())).getData(0).round() == 1
}
fun InputNumber.fizzBuzz(filters: List<Filter>): String {
val found = filters.find { it.match(this) }
return when(found?.fizzBuzz) {
FizzBuzz.FIZZ_BUZZ -> FizzBuzz.FIZZ_BUZZ.name
FizzBuzz.BUZZ -> FizzBuzz.BUZZ.name
FizzBuzz.FIZZ -> FizzBuzz.FIZZ.name
else -> "$num"
}
}
プログラム実行
学習データは201〜1000、テストするデータは1〜40として機械学習によるFizzBuzzを実行してみます。
object StudyRange {
val STUDY_START = 201
val STUDY_END = 1000
}
fun main(args: Array<String>) {
val inputList: List<InputNumber> =
StudyRange.STUDY_START.rangeTo(StudyRange.STUDY_END).map(toInput)
val studyDataList: List<StudyData> = fizzBuzzList().map { StudyData(inputList, it) }
val filters: List<Filter> = studyDataList.map { Filter(it.dataSet(), it.fizzBuzz) }
val testData: List<InputNumber> = (1..40).map(toInput)
testData.map {
"${it.num} -> ${it.fizzBuzz(filters)} (correct: ${it.correct()})"
}.forEach{
println(it)
}
Encog.getInstance().shutdown()
}
実行結果
1 -> 1 (correct: 1)
2 -> 2 (correct: 2)
3 -> FIZZ (correct: FIZZ)
4 -> 4 (correct: 4)
5 -> BUZZ (correct: BUZZ)
6 -> FIZZ (correct: FIZZ)
7 -> 7 (correct: 7)
8 -> 8 (correct: 8)
9 -> FIZZ (correct: FIZZ)
10 -> BUZZ (correct: BUZZ)
11 -> 11 (correct: 11)
12 -> FIZZ (correct: FIZZ)
13 -> 13 (correct: 13)
14 -> 14 (correct: 14)
15 -> FIZZ_BUZZ (correct: FIZZ_BUZZ)
16 -> 16 (correct: 16)
17 -> 17 (correct: 17)
18 -> FIZZ (correct: FIZZ)
19 -> 19 (correct: 19)
20 -> BUZZ (correct: BUZZ)
21 -> FIZZ (correct: FIZZ)
22 -> 22 (correct: 22)
23 -> 23 (correct: 23)
24 -> FIZZ (correct: FIZZ)
25 -> BUZZ (correct: BUZZ)
26 -> 26 (correct: 26)
27 -> FIZZ (correct: FIZZ)
28 -> 28 (correct: 28)
29 -> 29 (correct: 29)
30 -> FIZZ_BUZZ (correct: FIZZ_BUZZ)
31 -> 31 (correct: 31)
32 -> 32 (correct: 32)
33 -> FIZZ (correct: FIZZ)
34 -> 34 (correct: 34)
35 -> BUZZ (correct: BUZZ)
36 -> FIZZ (correct: FIZZ)
37 -> 37 (correct: 37)
38 -> 38 (correct: 38)
39 -> FIZZ (correct: FIZZ)
40 -> BUZZ (correct: BUZZ)
どうやらうまくいったようです。
全体のコードはgistにあります。
機械学習でFizzBuzz · GitHub