mike-neckのブログ

Java or Groovy or Swift or Golang

JUnit Platform Engine の作り方

JUnit Jupiter でのテストクラスの書き方は、世の中にたくさん、このブログよりも SEO に長けているブログの記事が見つかるので、それを参照されるがよい。

f:id:mike_neck:20200316210134p:plain

そうではなくて、ここでは JUnit Platform Engine の書き方、作り方の手引を書く。

ちなみに、僕は Kotlin 用のテストフレームワークを作りたくなったので、 Junit Platform Engine の使い方を調べていた。念の為、それについて書かれているブログ記事を探してみたが、あまり教養がないので見つけられなかった。多分、公式ドキュメントの 「6.1.3. Plugging in your own Test Engine」 が唯一まとまった文章だが、たかだか 20 行弱しかない。

この記事はこれからテストフレームワークを作るというマニアックな人のための導入になるかもしれない。


作るもの

単純に以下の2つを作る

  • TestEngine の実装クラスを一つ
  • TestDescriptor の実装クラスを1つ、インスタンスはテストメソッド・クラス・エンジンそれぞれについて作る

TestEngine

エンジンのエントリーポイント。 JUnit Launcher は ServiceLoader の仕組みを使って、エンジンのインスタンスを取得して起動する

  • discover - これによりテストを表すクラス・メソッドなどをスキャンする。スキャンした内容を TestDescriptor に入れて返す
  • getId - エンジンを表す ID。エンジンの ID として UniqueId がここから作られる
  • execute - テストを実行する。パラメーターの ExecutionRequest から discover でスキャンしたテストの集合が得られるので、それを一つずつ実行する。 実行の開始、終了は ExecutionRequest から取得できる EngineExecutionListener に都度報告する

TestDescriptor

テスト、あるいはテストを内包するコンテナクラス、あるいはエンジンのメタ情報をもつオブジェクト。以下のヒエラルキーを持つ。

  1. エンジンディスクリプター - テストエンジンを表す。 children に 2. を持つ/ getSource は必ず Optional#empty を返す
  2. コンテナディスクリプター - テストケースクラスを表す。 children に 3. を持つ/ getSourceClassSource を返す
  3. メソッドディスクリプター - テストメソッド、つまり個々のテストを表す/ getSourceMethodSource を返す

次のような雑なエンジン(スキャンすらしない)を実行すると、末尾にあるようなレポートが出力される

import org.junit.platform.engine.EngineDiscoveryRequest
import org.junit.platform.engine.ExecutionRequest
import org.junit.platform.engine.TestDescriptor
import org.junit.platform.engine.TestEngine
import org.junit.platform.engine.TestExecutionResult
import org.junit.platform.engine.TestSource
import org.junit.platform.engine.TestTag
import org.junit.platform.engine.UniqueId
import org.junit.platform.engine.reporting.ReportEntry
import org.junit.platform.engine.support.descriptor.ClassSource
import org.junit.platform.engine.support.descriptor.MethodSource
import org.junit.platform.engine.support.hierarchical.EngineExecutionContext
import org.junit.platform.engine.support.hierarchical.HierarchicalTestEngine
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.util.*

class KotlinTestEngine : TestEngine {

  override fun discover(discoveryRequest: EngineDiscoveryRequest?, uniqueId: UniqueId?): TestDescriptor {
    if (uniqueId == null) {
      throw IllegalArgumentException("uniqueId is null")
    }
    if (discoveryRequest == null) {
      throw IllegalArgumentException("engine discovery request is null")
    }
    logger.info("discover request, id: {}, request: {}", uniqueId, discoveryRequest.configurationParameters)
    return KotlinTestDescriptor
  }

  override fun getId(): String = "k-check"

  override fun execute(request: ExecutionRequest?) {
    if (request == null) {
      throw IllegalArgumentException("execution request is null")
    }
    val listener = request.engineExecutionListener

    logger.info("request root desc {}", request.rootTestDescriptor)
    logger.info("execution start {} param: {}", KotlinTestDescriptor, request.configurationParameters)
    listener.executionStarted(KotlinTestDescriptor)
    logger.info("execution start {} param: {}", FooDescriptor, request.configurationParameters)
    listener.executionStarted(FooDescriptor)
    logger.info("publish entry foo")
    listener.reportingEntryPublished(FooDescriptor, ReportEntry.from("foo", "FOO"))
    listener.executionStarted(BarDescriptor)
    listener.reportingEntryPublished(BarDescriptor, ReportEntry.from("test", "bar"))
    logger.info("execution finish {}", BarDescriptor::class.simpleName)
    listener.executionFinished(BarDescriptor, TestExecutionResult.successful())
    logger.info("execution start {}", BazDescriptor::class.simpleName)
    listener.executionStarted(BazDescriptor)
    logger.info("execution finish {}", BazDescriptor::class.simpleName)
    listener.executionFinished(BazDescriptor, TestExecutionResult.failed(Impediments.failed("baz fail", "1", 200)))
    logger.info("execution finish {}", FooDescriptor::class.simpleName)
    listener.executionFinished(FooDescriptor, TestExecutionResult.successful())
    logger.info("execution finish {}", KotlinTestDescriptor::class.simpleName)
    listener.executionFinished(KotlinTestDescriptor, TestExecutionResult.successful())
  }

  companion object {
    val logger: Logger = LoggerFactory.getLogger(KotlinTestEngine::class.java)
  }
}

object KotlinTestDescriptor : TestDescriptor {

  private val logger: Logger = LoggerFactory.getLogger(KotlinTestDescriptor::class.java)

  override fun getSource(): Optional<TestSource> = Optional.empty()
  override fun removeFromHierarchy() = logger.info("remove from hierarchy call")
  override fun setParent(parent: TestDescriptor?) = logger.info("set parent call {}", parent)
  override fun getParent(): Optional<TestDescriptor> = Optional.empty<TestDescriptor>().also { logger.info("get parent call") }
  override fun getChildren(): MutableSet<out TestDescriptor> = mutableSetOf(FooDescriptor).also { logger.info("get children call") }
  override fun getDisplayName(): String = "k-check".also { logger.info("get display name call, {}", it) }
  override fun getType(): TestDescriptor.Type = TestDescriptor.Type.CONTAINER
  override fun getUniqueId(): UniqueId = UniqueId.forEngine("k-check")
  override fun removeChild(descriptor: TestDescriptor?) = logger.info("remove child call id:{}", descriptor?.uniqueId)
  override fun addChild(descriptor: TestDescriptor?) = logger.info("add child call id: {}", descriptor?.uniqueId)
  override fun findByUniqueId(uniqueId: UniqueId?): Optional<out TestDescriptor> = logger.info("find by unique id call, id: {}", uniqueId).let {
    when (uniqueId) {
      UniqueId.forEngine("k-check").append("test", "foo") -> Optional.of(FooDescriptor)
      else -> Optional.empty()
    }
  }

  override fun getTags(): MutableSet<TestTag> = mutableSetOf()
}

object FooDescriptor : TestDescriptor {
  private val logger: Logger = LoggerFactory.getLogger(FooDescriptor::class.java)

  override fun getSource(): Optional<TestSource> =
      Optional.of(ClassSource.from(FooDescriptor::class.java))

  override fun removeFromHierarchy() = Unit
  override fun setParent(parent: TestDescriptor?) = Unit.also { logger.info("set parent call, id: {}", parent?.uniqueId) }
  override fun getParent(): Optional<TestDescriptor> = Optional.of<TestDescriptor>(KotlinTestDescriptor).also { logger.info("get parent call") }
  override fun getChildren(): MutableSet<out TestDescriptor> = mutableSetOf(BarDescriptor, BazDescriptor).also { logger.info("get children call") }
  override fun getDisplayName(): String = "${FooDescriptor::class.simpleName}".also { logger.info("get display name call, {}", it) }
  override fun getType(): TestDescriptor.Type = TestDescriptor.Type.CONTAINER_AND_TEST
  override fun getUniqueId(): UniqueId = UniqueId.forEngine("k-check").append("test", "foo")
  override fun removeChild(descriptor: TestDescriptor?) = Unit
  override fun addChild(descriptor: TestDescriptor?) = Unit
  override fun findByUniqueId(uniqueId: UniqueId?): Optional<out TestDescriptor> = Optional.empty()
  override fun getTags(): MutableSet<TestTag> = mutableSetOf()
}

object BarDescriptor : TestDescriptor {
  private val logger: Logger = LoggerFactory.getLogger(BarDescriptor::class.java)
  override fun getSource(): Optional<TestSource> = Optional.of(MethodSource.from(BarDescriptor::class.java.simpleName, "test"))
  override fun removeFromHierarchy() = logger.info("remove from hierarchy call")
  override fun setParent(parent: TestDescriptor?) = logger.info("set parent call {}", parent)
  override fun getParent(): Optional<TestDescriptor> = logger.info("get parent call").let { Optional.of(FooDescriptor) }
  override fun getChildren(): MutableSet<out TestDescriptor> = logger.info("get children call").let { mutableSetOf() }
  override fun getDisplayName(): String = "${this::class.simpleName}"
  override fun getType(): TestDescriptor.Type = TestDescriptor.Type.TEST
  override fun getUniqueId(): UniqueId = UniqueId.forEngine("k-check").append("test", "foo").append("test", "bar")
  override fun removeChild(descriptor: TestDescriptor?) = logger.info("remove child call, {}", descriptor)
  override fun addChild(descriptor: TestDescriptor?) = logger.info("add child call, {}", descriptor)
  override fun findByUniqueId(uniqueId: UniqueId?): Optional<out TestDescriptor> = logger.info("find by unique id call, {}", uniqueId).let { Optional.empty() }
  override fun getTags(): MutableSet<TestTag> = logger.info("get tags call").let { mutableSetOf(TestTag.create("foo"), TestTag.create("bar")) }
}

object BazDescriptor : TestDescriptor {
  private val logger: Logger = LoggerFactory.getLogger(BazDescriptor::class.java)
  override fun getSource(): Optional<TestSource> = Optional.of(MethodSource.from(BazDescriptor::class.java.simpleName, "test"))
  override fun removeFromHierarchy() = logger.info("remove from hierarchy call")
  override fun setParent(parent: TestDescriptor?) = logger.info("set parent call {}", parent)
  override fun getParent(): Optional<TestDescriptor> = logger.info("get parent call").let { Optional.of(FooDescriptor) }
  override fun getChildren(): MutableSet<out TestDescriptor> = logger.info("get children call").let { mutableSetOf() }
  override fun getDisplayName(): String = "${this::class.simpleName}"
  override fun getType(): TestDescriptor.Type = TestDescriptor.Type.TEST
  override fun getUniqueId(): UniqueId = UniqueId.forEngine("k-check").append("test", "foo").append("test", "baz")
  override fun removeChild(descriptor: TestDescriptor?) = logger.info("remove child call, {}", descriptor)
  override fun addChild(descriptor: TestDescriptor?) = logger.info("add child call, {}", descriptor)
  override fun findByUniqueId(uniqueId: UniqueId?): Optional<out TestDescriptor> = logger.info("find by unique id call, {}", uniqueId).let { Optional.empty() }
  override fun getTags(): MutableSet<TestTag> = logger.info("get tags call").let { mutableSetOf(TestTag.create("foo"), TestTag.create("baz")) }
}

f:id:mike_neck:20200316211725p:plain

f:id:mike_neck:20200316211811p:plain