diff --git a/src/test/scala/com/typesafe/config/impl/PublicApiTest.scala b/src/test/scala/com/typesafe/config/impl/PublicApiTest.scala index c4ea01aa..07f23720 100644 --- a/src/test/scala/com/typesafe/config/impl/PublicApiTest.scala +++ b/src/test/scala/com/typesafe/config/impl/PublicApiTest.scala @@ -7,6 +7,7 @@ import com.typesafe.config._ import java.util.Collections import java.util.TreeSet import java.io.File +import scala.collection.mutable class PublicApiTest extends TestUtils { @Test @@ -215,4 +216,68 @@ class PublicApiTest extends TestUtils { val conf2 = Config.parse(resource("test03.conf"), ConfigParseOptions.defaults().setAllowMissing(true)) assertEquals(conf, conf2) } + + case class Included(name: String, fallback: ConfigIncluder) + + class RecordingIncluder(val fallback: ConfigIncluder, val included: mutable.ListBuffer[Included]) extends ConfigIncluder { + override def include(context: ConfigIncludeContext, name: String): ConfigObject = { + included += Included(name, fallback) + fallback.include(context, name) + } + + override def withFallback(fallback: ConfigIncluder) = { + if (this.fallback == fallback) { + this; + } else if (this.fallback == null) { + new RecordingIncluder(fallback, included); + } else { + new RecordingIncluder(this.fallback.withFallback(fallback), included) + } + } + } + + private def whatWasIncluded(parser: ConfigParseOptions => ConfigObject): List[Included] = { + val included = mutable.ListBuffer[Included]() + val includer = new RecordingIncluder(null, included) + + val conf = parser(ConfigParseOptions.defaults().setIncluder(includer).setAllowMissing(false)) + + included.toList + } + + @Test + def includersAreUsedWithFiles() { + val included = whatWasIncluded(Config.parse(resource("test03.conf"), _)) + + assertEquals(List("test01", "test02.conf", "equiv01/original.json", + "nothere", "nothere.conf", "nothere.json", "nothere.properties"), + included.map(_.name)) + } + + @Test + def includersAreUsedRecursivelyWithFiles() { + // includes.conf has recursive includes in it + val included = whatWasIncluded(Config.parse(resource("equiv03/includes.conf"), _)) + + assertEquals(List("letters/a.conf", "numbers/1.conf", "numbers/2", "letters/b.json", "letters/c"), + included.map(_.name)) + } + + @Test + def includersAreUsedWithClasspath() { + val included = whatWasIncluded(Config.parse(classOf[PublicApiTest], "/test03.conf", _)) + + assertEquals(List("test01", "test02.conf", "equiv01/original.json", + "nothere", "nothere.conf", "nothere.json", "nothere.properties"), + included.map(_.name)) + } + + @Test + def includersAreUsedRecursivelyWithClasspath() { + // includes.conf has recursive includes in it + val included = whatWasIncluded(Config.parse(classOf[PublicApiTest], "/equiv03/includes.conf", _)) + + assertEquals(List("letters/a.conf", "numbers/1.conf", "numbers/2", "letters/b.json", "letters/c"), + included.map(_.name)) + } }