diff --git a/.changes/next-release/feature-9d99ada83fb4eaabb1f0d929c0da87d73134dd3e.json b/.changes/next-release/feature-9d99ada83fb4eaabb1f0d929c0da87d73134dd3e.json new file mode 100644 index 00000000000..53a281a25ed --- /dev/null +++ b/.changes/next-release/feature-9d99ada83fb4eaabb1f0d929c0da87d73134dd3e.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "description": "Add the ability for smithy build plugins to declare that they must be run before or after other plugins. These dependencies are soft, so missing dependencies will be logged and ignored.", + "pull_requests": [ + "[#2774](https://github.com/smithy-lang/smithy/pull/2774)" + ] +} diff --git a/.changes/next-release/feature-d368237cdeb1655c7c10969e36c6aa3ca3b3034a.json b/.changes/next-release/feature-d368237cdeb1655c7c10969e36c6aa3ca3b3034a.json new file mode 100644 index 00000000000..d47a8c41748 --- /dev/null +++ b/.changes/next-release/feature-d368237cdeb1655c7c10969e36c6aa3ca3b3034a.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "description": "Add a generic dependency graph to smithy-utils to be used for sorting various dependent objects, such as integrations and plugins.", + "pull_requests": [ + "[#2774](https://github.com/smithy-lang/smithy/pull/2774)" + ] +} diff --git a/smithy-build/src/main/java/software/amazon/smithy/build/SmithyBuildImpl.java b/smithy-build/src/main/java/software/amazon/smithy/build/SmithyBuildImpl.java index 4df3995e3f1..75dd97c11fa 100644 --- a/smithy-build/src/main/java/software/amazon/smithy/build/SmithyBuildImpl.java +++ b/smithy-build/src/main/java/software/amazon/smithy/build/SmithyBuildImpl.java @@ -33,6 +33,8 @@ import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.model.transform.ModelTransformer; import software.amazon.smithy.model.validation.ValidatedResult; +import software.amazon.smithy.utils.CycleException; +import software.amazon.smithy.utils.DependencyGraph; import software.amazon.smithy.utils.Pair; import software.amazon.smithy.utils.SmithyBuilder; @@ -214,7 +216,8 @@ void applyAllProjections( private List resolvePlugins(String projectionName, ProjectionConfig config) { // Ensure that no two plugins use the same artifact name. Set seenArtifactNames = new HashSet<>(); - List resolvedPlugins = new ArrayList<>(); + DependencyGraph dependencyGraph = new DependencyGraph<>(); + Map resolvedPlugins = new HashMap<>(); for (Map.Entry pluginEntry : getCombinedPlugins(config).entrySet()) { PluginId id = PluginId.from(pluginEntry.getKey()); @@ -224,12 +227,45 @@ private List resolvePlugins(String projectionName, ProjectionCon id.getArtifactName(), projectionName)); } + createPlugin(projectionName, id).ifPresent(plugin -> { - resolvedPlugins.add(new ResolvedPlugin(id, plugin, pluginEntry.getValue())); + dependencyGraph.add(id); + for (String dependency : plugin.runAfter()) { + dependencyGraph.addDependency(id, PluginId.from(dependency)); + } + for (String dependant : plugin.runBefore()) { + dependencyGraph.addDependency(PluginId.from(dependant), id); + } + resolvedPlugins.put(id, new ResolvedPlugin(id, plugin, pluginEntry.getValue())); }); } - return resolvedPlugins; + List sorted; + try { + sorted = dependencyGraph.toSortedList(); + } catch (CycleException e) { + throw new SmithyBuildException(e.getMessage(), e); + } + + List result = new ArrayList<>(); + for (PluginId id : sorted) { + ResolvedPlugin resolvedPlugin = resolvedPlugins.get(id); + if (resolvedPlugin != null) { + result.add(resolvedPlugin); + continue; + } + + // If the plugin wasn't resolved, that's either because it was declared but not + // available on the classpath or not declared at all. In the former case we + // already have a log message that covers it, including a default build failure. + // If the plugin was seen, it was declared, and we don't need to log a second + // time. + if (!seenArtifactNames.contains(id.getArtifactName())) { + logMissingPluginDependency(dependencyGraph, id); + } + } + + return result; } private Map getCombinedPlugins(ProjectionConfig projection) { @@ -257,6 +293,33 @@ private Optional createPlugin(String projectionName, PluginId throw new SmithyBuildException(message); } + private void logMissingPluginDependency(DependencyGraph dependencyGraph, PluginId name) { + StringBuilder message = new StringBuilder("Could not find plugin named '"); + message.append(name).append('\''); + if (!dependencyGraph.getDirectDependants(name).isEmpty()) { + message.append(" that was supposed to run before plugins ["); + message.append(dependencyGraph.getDirectDependants(name) + .stream() + .map(PluginId::toString) + .collect(Collectors.joining(", "))); + message.append("]"); + } + if (!dependencyGraph.getDirectDependencies(name).isEmpty()) { + if (!dependencyGraph.getDirectDependants(name).isEmpty()) { + message.append(" and "); + } else { + message.append(" that "); + } + message.append("was supposed to run after plugins ["); + message.append(dependencyGraph.getDirectDependencies(name) + .stream() + .map(PluginId::toString) + .collect(Collectors.joining(", "))); + message.append("]"); + } + LOGGER.warning(message.toString()); + } + private boolean areAnyResolvedPluginsSerial(List resolvedPlugins) { for (ResolvedPlugin plugin : resolvedPlugins) { if (plugin.plugin.isSerial()) { diff --git a/smithy-build/src/main/java/software/amazon/smithy/build/SmithyBuildPlugin.java b/smithy-build/src/main/java/software/amazon/smithy/build/SmithyBuildPlugin.java index d7c34ff3e75..5e2ebb99761 100644 --- a/smithy-build/src/main/java/software/amazon/smithy/build/SmithyBuildPlugin.java +++ b/smithy-build/src/main/java/software/amazon/smithy/build/SmithyBuildPlugin.java @@ -5,6 +5,7 @@ package software.amazon.smithy.build; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.ServiceLoader; @@ -104,4 +105,28 @@ static Function> createServiceFactory() { static Function> createServiceFactory(ClassLoader classLoader) { return createServiceFactory(ServiceLoader.load(SmithyBuildPlugin.class, classLoader)); } + + /** + * Gets the names of plugins that this plugin must come before. + * + *

Dependencies are soft. Dependencies on plugin names that cannot be found + * log a warning and are ignored. + * + * @return Returns the plugin names this must come before. + */ + default List runBefore() { + return Collections.emptyList(); + } + + /** + * Gets the name of the plugins that this plugin must come after. + * + *

Dependencies are soft. Dependencies on plugin names that cannot be found + * log a warning and are ignored. + * + * @return Returns the plugins names this must come after. + */ + default List runAfter() { + return Collections.emptyList(); + } } diff --git a/smithy-build/src/test/java/software/amazon/smithy/build/CyclicPlugin1.java b/smithy-build/src/test/java/software/amazon/smithy/build/CyclicPlugin1.java new file mode 100644 index 00000000000..a7cd768b949 --- /dev/null +++ b/smithy-build/src/test/java/software/amazon/smithy/build/CyclicPlugin1.java @@ -0,0 +1,23 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.build; + +import java.util.List; +import software.amazon.smithy.utils.ListUtils; + +public class CyclicPlugin1 implements SmithyBuildPlugin { + @Override + public String getName() { + return "cyclicplugin1"; + } + + @Override + public void execute(PluginContext context) {} + + @Override + public List runBefore() { + return ListUtils.of("cyclicplugin2"); + } +} diff --git a/smithy-build/src/test/java/software/amazon/smithy/build/CyclicPlugin2.java b/smithy-build/src/test/java/software/amazon/smithy/build/CyclicPlugin2.java new file mode 100644 index 00000000000..cfb571d8f47 --- /dev/null +++ b/smithy-build/src/test/java/software/amazon/smithy/build/CyclicPlugin2.java @@ -0,0 +1,23 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.build; + +import java.util.List; +import software.amazon.smithy.utils.ListUtils; + +public class CyclicPlugin2 implements SmithyBuildPlugin { + @Override + public String getName() { + return "cyclicplugin2"; + } + + @Override + public void execute(PluginContext context) {} + + @Override + public List runBefore() { + return ListUtils.of("cyclicplugin1"); + } +} diff --git a/smithy-build/src/test/java/software/amazon/smithy/build/SmithyBuildTest.java b/smithy-build/src/test/java/software/amazon/smithy/build/SmithyBuildTest.java index f41d4665ae6..e6fc121f251 100644 --- a/smithy-build/src/test/java/software/amazon/smithy/build/SmithyBuildTest.java +++ b/smithy-build/src/test/java/software/amazon/smithy/build/SmithyBuildTest.java @@ -13,6 +13,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.File; @@ -23,6 +24,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.time.Instant; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -539,6 +541,92 @@ private void assertPluginPresent(String pluginName, String outputFileName, Proje } } + @Test + public void topologicallySortsPlugins() throws Exception { + Map plugins = MapUtils.of( + "timestampPlugin1", + new TimestampPlugin1(), + "timestampPlugin2", + new TimestampPlugin2(), + "timestampPlugin3", + new TimestampPlugin3()); + Function> factory = SmithyBuildPlugin.createServiceFactory(); + Function> composed = name -> OptionalUtils.or( + Optional.ofNullable(plugins.get(name)), + () -> factory.apply(name)); + + SmithyBuild builder = new SmithyBuild().pluginFactory(composed); + builder.fileManifestFactory(MockManifest::new); + builder.config(SmithyBuildConfig.builder() + .load(Paths.get(getClass().getResource("topologically-sorts-plugins.json").toURI())) + .outputDirectory("/foo") + .build()); + + SmithyBuildResult results = builder.build(); + ProjectionResult source = results.getProjectionResult("source").get(); + MockManifest manifest1 = (MockManifest) source.getPluginManifest("timestampPlugin1").get(); + MockManifest manifest2 = (MockManifest) source.getPluginManifest("timestampPlugin2").get(); + MockManifest manifest3 = (MockManifest) source.getPluginManifest("timestampPlugin3").get(); + + Instant instant1 = Instant.ofEpochMilli(Long.parseLong(manifest1.getFileString("timestamp").get())); + Instant instant2 = Instant.ofEpochMilli(Long.parseLong(manifest2.getFileString("timestamp").get())); + Instant instant3 = Instant.ofEpochMilli(Long.parseLong(manifest3.getFileString("timestamp").get())); + + assertTrue(instant2.isBefore(instant1)); + assertTrue(instant3.isAfter(instant1)); + } + + @Test + public void dependenciesAreSoft() throws Exception { + Map plugins = MapUtils.of( + "timestampPlugin2", + new TimestampPlugin2(), + "timestampPlugin3", + new TimestampPlugin3()); + Function> factory = SmithyBuildPlugin.createServiceFactory(); + Function> composed = name -> OptionalUtils.or( + Optional.ofNullable(plugins.get(name)), + () -> factory.apply(name)); + + SmithyBuild builder = new SmithyBuild().pluginFactory(composed); + builder.fileManifestFactory(MockManifest::new); + builder.config(SmithyBuildConfig.builder() + .load(Paths.get(getClass().getResource("soft-plugin-dependencies.json").toURI())) + .outputDirectory("/foo") + .build()); + + SmithyBuildResult results = builder.build(); + ProjectionResult source = results.getProjectionResult("source").get(); + MockManifest manifest2 = (MockManifest) source.getPluginManifest("timestampPlugin2").get(); + MockManifest manifest3 = (MockManifest) source.getPluginManifest("timestampPlugin3").get(); + + assertTrue(manifest2.hasFile("timestamp")); + assertTrue(manifest3.hasFile("timestamp")); + } + + @Test + public void detectsPluginCycles() throws Exception { + Map plugins = MapUtils.of( + "cyclicplugin1", + new CyclicPlugin1(), + "cyclicplugin2", + new CyclicPlugin2()); + Function> factory = SmithyBuildPlugin.createServiceFactory(); + Function> composed = name -> OptionalUtils.or( + Optional.ofNullable(plugins.get(name)), + () -> factory.apply(name)); + + SmithyBuild builder = new SmithyBuild().pluginFactory(composed); + builder.fileManifestFactory(MockManifest::new); + builder.config(SmithyBuildConfig.builder() + .load(Paths.get(getClass().getResource("detects-plugin-cycles.json").toURI())) + .outputDirectory("/foo") + .build()); + + SmithyBuildException e = assertThrows(SmithyBuildException.class, builder::build); + assertThat(e.getMessage(), containsString("Cycle(s) detected")); + } + @Test public void buildCanOverrideConfigOutputDirectory() throws Exception { Path outputDirectory = Paths.get("/custom/foo"); diff --git a/smithy-build/src/test/java/software/amazon/smithy/build/TimestampPlugin1.java b/smithy-build/src/test/java/software/amazon/smithy/build/TimestampPlugin1.java new file mode 100644 index 00000000000..33257dc2be0 --- /dev/null +++ b/smithy-build/src/test/java/software/amazon/smithy/build/TimestampPlugin1.java @@ -0,0 +1,30 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.build; + +import java.time.Instant; + +public class TimestampPlugin1 implements SmithyBuildPlugin { + @Override + public String getName() { + return "timestampPlugin1"; + } + + @Override + public void execute(PluginContext context) { + context.getFileManifest().writeFile("timestamp", String.valueOf(Instant.now().toEpochMilli())); + try { + Thread.sleep(1); + } catch (InterruptedException ignored) {} + } + + // This is made serial to protect against test failures if we decide later to + // have plugins run in parallel. These plugins MUST run serially with respect + // to each other to function. + @Override + public boolean isSerial() { + return true; + } +} diff --git a/smithy-build/src/test/java/software/amazon/smithy/build/TimestampPlugin2.java b/smithy-build/src/test/java/software/amazon/smithy/build/TimestampPlugin2.java new file mode 100644 index 00000000000..7903f5ae36f --- /dev/null +++ b/smithy-build/src/test/java/software/amazon/smithy/build/TimestampPlugin2.java @@ -0,0 +1,37 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.build; + +import java.time.Instant; +import java.util.List; +import software.amazon.smithy.utils.ListUtils; + +public class TimestampPlugin2 implements SmithyBuildPlugin { + @Override + public String getName() { + return "timestampPlugin2"; + } + + @Override + public void execute(PluginContext context) { + context.getFileManifest().writeFile("timestamp", String.valueOf(Instant.now().toEpochMilli())); + try { + Thread.sleep(1); + } catch (InterruptedException ignored) {} + } + + @Override + public List runBefore() { + return ListUtils.of("timestampPlugin1"); + } + + // This is made serial to protect against test failures if we decide later to + // have plugins run in parallel. These plugins MUST run serially with respect + // to each other to function. + @Override + public boolean isSerial() { + return true; + } +} diff --git a/smithy-build/src/test/java/software/amazon/smithy/build/TimestampPlugin3.java b/smithy-build/src/test/java/software/amazon/smithy/build/TimestampPlugin3.java new file mode 100644 index 00000000000..c508062b642 --- /dev/null +++ b/smithy-build/src/test/java/software/amazon/smithy/build/TimestampPlugin3.java @@ -0,0 +1,37 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.build; + +import java.time.Instant; +import java.util.List; +import software.amazon.smithy.utils.ListUtils; + +public class TimestampPlugin3 implements SmithyBuildPlugin { + @Override + public String getName() { + return "timestampPlugin3"; + } + + @Override + public void execute(PluginContext context) { + context.getFileManifest().writeFile("timestamp", String.valueOf(Instant.now().toEpochMilli())); + try { + Thread.sleep(1); + } catch (InterruptedException ignored) {} + } + + @Override + public List runAfter() { + return ListUtils.of("timestampPlugin1"); + } + + // This is made serial to protect against test failures if we decide later to + // have plugins run in parallel. These plugins MUST run serially with respect + // to each other to function. + @Override + public boolean isSerial() { + return true; + } +} diff --git a/smithy-build/src/test/resources/software/amazon/smithy/build/detects-plugin-cycles.json b/smithy-build/src/test/resources/software/amazon/smithy/build/detects-plugin-cycles.json new file mode 100644 index 00000000000..35a3811ae52 --- /dev/null +++ b/smithy-build/src/test/resources/software/amazon/smithy/build/detects-plugin-cycles.json @@ -0,0 +1,7 @@ +{ + "version": "2.0", + "plugins": { + "cyclicplugin1": {}, + "cyclicplugin2": {} + } +} diff --git a/smithy-build/src/test/resources/software/amazon/smithy/build/soft-plugin-dependencies.json b/smithy-build/src/test/resources/software/amazon/smithy/build/soft-plugin-dependencies.json new file mode 100644 index 00000000000..cd1bedca805 --- /dev/null +++ b/smithy-build/src/test/resources/software/amazon/smithy/build/soft-plugin-dependencies.json @@ -0,0 +1,7 @@ +{ + "version": "2.0", + "plugins": { + "timestampPlugin2": {}, + "timestampPlugin3": {} + } +} diff --git a/smithy-build/src/test/resources/software/amazon/smithy/build/topologically-sorts-plugins.json b/smithy-build/src/test/resources/software/amazon/smithy/build/topologically-sorts-plugins.json new file mode 100644 index 00000000000..d70078f6d63 --- /dev/null +++ b/smithy-build/src/test/resources/software/amazon/smithy/build/topologically-sorts-plugins.json @@ -0,0 +1,8 @@ +{ + "version": "2.0", + "plugins": { + "timestampPlugin1": {}, + "timestampPlugin2": {}, + "timestampPlugin3": {} + } +} diff --git a/smithy-codegen-core/build.gradle.kts b/smithy-codegen-core/build.gradle.kts index a8bd35c07ca..1036d142cc2 100644 --- a/smithy-codegen-core/build.gradle.kts +++ b/smithy-codegen-core/build.gradle.kts @@ -4,6 +4,7 @@ */ plugins { id("smithy.module-conventions") + id("smithy.profiling-conventions") } description = "This module provides a code generation framework for generating clients, " + @@ -14,6 +15,7 @@ extra["moduleName"] = "software.amazon.smithy.codegen.core" dependencies { api(project(":smithy-utils")) + jmh(project(":smithy-utils")) api(project(":smithy-model")) api(project(":smithy-build")) } diff --git a/smithy-codegen-core/src/jmh/java/software/amazon/smithy/codegen/core/jmh/SmithyIntegrations.java b/smithy-codegen-core/src/jmh/java/software/amazon/smithy/codegen/core/jmh/SmithyIntegrations.java new file mode 100644 index 00000000000..fc1b94f04a0 --- /dev/null +++ b/smithy-codegen-core/src/jmh/java/software/amazon/smithy/codegen/core/jmh/SmithyIntegrations.java @@ -0,0 +1,230 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.codegen.core.jmh; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import software.amazon.smithy.codegen.core.CodegenContext; +import software.amazon.smithy.codegen.core.ImportContainer; +import software.amazon.smithy.codegen.core.SmithyIntegration; +import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.codegen.core.SymbolWriter; + +@Warmup(iterations = 3) +@Measurement(iterations = 3, timeUnit = TimeUnit.MICROSECONDS) +@BenchmarkMode(Mode.AverageTime) +@Fork(1) +public class SmithyIntegrations { + + @State(Scope.Thread) + public static class SmithyIntegrationsState { + public Map> independentIntegrations10; + public Map> dependentIntegrations10; + public Map> independentIntegrations100; + public Map> dependentIntegrations100; + public Map> independentIntegrations1000; + public Map> dependentIntegrations1000; + + @Setup + public void setup() { + independentIntegrations10 = new LinkedHashMap<>(10); + independentIntegrations100 = new LinkedHashMap<>(100); + independentIntegrations1000 = new LinkedHashMap<>(1000); + for (int i = 0; i < 1000; i++) { + String name = "integration" + i; + TestIntegration integration = new TestIntegration(name); + if (i < 10) { + independentIntegrations10.put(name, integration); + } + if (i < 100) { + independentIntegrations100.put(name, integration); + } + independentIntegrations1000.put(name, integration); + } + independentIntegrations10 = Collections.unmodifiableMap(independentIntegrations10); + independentIntegrations100 = Collections.unmodifiableMap(independentIntegrations100); + independentIntegrations1000 = Collections.unmodifiableMap(independentIntegrations1000); + + dependentIntegrations10 = new LinkedHashMap<>(10); + dependentIntegrations100 = new LinkedHashMap<>(100); + dependentIntegrations1000 = new LinkedHashMap<>(1000); + for (int i = 0; i < 1000; i++) { + String name = "integration" + i; + String next = "integration" + (i + 1); + String afterNext = "integration" + (i + 2); + + if (i < 10) { + List dependencies = new ArrayList<>(); + if (i < 9) { + dependencies.add(next); + } + if (i < 8) { + dependencies.add(afterNext); + } + dependentIntegrations10.put(name, + new TestIntegration(name, (byte) 0, Collections.emptyList(), dependencies)); + } + + if (i < 100) { + List dependencies = new ArrayList<>(); + if (i < 99) { + dependencies.add(next); + } + if (i < 98) { + dependencies.add(afterNext); + } + dependentIntegrations100.put(name, + new TestIntegration(name, (byte) 0, Collections.emptyList(), dependencies)); + } + + List dependencies = new ArrayList<>(); + if (i < 999) { + dependencies.add(next); + } + if (i < 998) { + dependencies.add(afterNext); + } + dependentIntegrations1000.put(name, + new TestIntegration(name, (byte) 0, Collections.emptyList(), dependencies)); + } + dependentIntegrations10 = Collections.unmodifiableMap(dependentIntegrations10); + dependentIntegrations100 = Collections.unmodifiableMap(dependentIntegrations100); + dependentIntegrations1000 = Collections.unmodifiableMap(dependentIntegrations1000); + } + } + + @Benchmark + public List> sortIndependentIntegrations10(SmithyIntegrationsState state) { + return SmithyIntegration.sort(state.independentIntegrations10.values()); + } + + @Benchmark + public List> sortIndependentIntegrations100(SmithyIntegrationsState state) { + return SmithyIntegration.sort(state.independentIntegrations100.values()); + } + + @Benchmark + public List> sortIndependentIntegrations1000(SmithyIntegrationsState state) { + return SmithyIntegration.sort(state.independentIntegrations1000.values()); + } + + @Benchmark + public List> sortDependentIntegrations10(SmithyIntegrationsState state) { + return SmithyIntegration.sort(state.dependentIntegrations10.values()); + } + + @Benchmark + public List> sortDependentIntegrations100(SmithyIntegrationsState state) { + return SmithyIntegration.sort(state.dependentIntegrations100.values()); + } + + @Benchmark + public List> sortDependentIntegrations1000(SmithyIntegrationsState state) { + return SmithyIntegration.sort(state.dependentIntegrations1000.values()); + } + + private static class IntegrationComparator implements Comparator { + + private final Map> lookup; + + IntegrationComparator(Map> lookup) { + this.lookup = lookup; + } + + @Override + public int compare(String o1, String o2) { + SmithyIntegration left = lookup.get(o1); + SmithyIntegration right = lookup.get(o2); + if (left == null || right == null) { + return 0; + } + return Byte.compare(left.priority(), right.priority()); + } + } + + private static class TestImportContainer implements ImportContainer { + @Override + public void importSymbol(Symbol symbol, String alias) { + + } + } + + private static class TestSymbolWriter extends SymbolWriter { + public TestSymbolWriter(TestImportContainer importContainer) { + super(importContainer); + } + } + + private static class TestSettings {} + + private static class TestIntegration implements SmithyIntegration< + TestSettings, + TestSymbolWriter, + CodegenContext> { + + private final String name; + private final byte priority; + private final List runBefore; + private final List runAfter; + + TestIntegration(String name) { + this(name, (byte) 0); + } + + TestIntegration(String name, byte priority) { + this(name, priority, Collections.emptyList()); + } + + TestIntegration(String name, byte priority, List runBefore) { + this(name, priority, runBefore, Collections.emptyList()); + } + + TestIntegration(String name, byte priority, List runBefore, List runAfter) { + this.name = name; + this.priority = priority; + this.runBefore = runBefore; + this.runAfter = runAfter; + } + + @Override + public String name() { + return name; + } + + @Override + public byte priority() { + return priority; + } + + @Override + public List runBefore() { + return runBefore; + } + + @Override + public List runAfter() { + return runAfter; + } + + @Override + public String toString() { + return name(); + } + } +} diff --git a/smithy-codegen-core/src/main/java/software/amazon/smithy/codegen/core/IntegrationTopologicalSort.java b/smithy-codegen-core/src/main/java/software/amazon/smithy/codegen/core/IntegrationTopologicalSort.java index 5bd563564fd..43e3c20494a 100644 --- a/smithy-codegen-core/src/main/java/software/amazon/smithy/codegen/core/IntegrationTopologicalSort.java +++ b/smithy-codegen-core/src/main/java/software/amazon/smithy/codegen/core/IntegrationTopologicalSort.java @@ -5,16 +5,13 @@ package software.amazon.smithy.codegen.core; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.PriorityQueue; -import java.util.Queue; -import java.util.Set; import java.util.logging.Logger; +import software.amazon.smithy.utils.CycleException; +import software.amazon.smithy.utils.DependencyGraph; final class IntegrationTopologicalSort> { @@ -22,102 +19,81 @@ final class IntegrationTopologicalSort> { private final Map integrationLookup = new LinkedHashMap<>(); private final Map insertionOrder = new HashMap<>(); - private final Map> forwardDependencies = new LinkedHashMap<>(); - private final Map> reverseDependencies = new LinkedHashMap<>(); - - private final Queue satisfied = new PriorityQueue<>((left, right) -> { - I leftIntegration = integrationLookup.get(left); - I rightIntegration = integrationLookup.get(right); - // Priority order is used to sort first. - int byteResult = Byte.compare(rightIntegration.priority(), leftIntegration.priority()); - // If priority is a tie, then sort based on insertion order of integrations. - // This makes the order deterministic. - return byteResult == 0 - ? Integer.compare(insertionOrder.get(left), insertionOrder.get(right)) - : byteResult; - }); + private final DependencyGraph dependencyGraph = new DependencyGraph<>(); IntegrationTopologicalSort(Iterable integrations) { - // Validate name conflicts and register integrations with the lookup table + insertion order table. for (I integration : integrations) { - addIntegration(integration); - } - - // Validate missing dependencies and add found dependencies. - for (I integration : integrations) { - for (String before : getValidatedDependencies("before", integration.name(), integration.runBefore())) { - addDependency(before, integration.name()); - } - for (String after : getValidatedDependencies("after", integration.name(), integration.runAfter())) { - addDependency(integration.name(), after); + I previous = this.integrationLookup.put(integration.name(), integration); + if (previous != null) { + throw new IllegalArgumentException(String.format( + "Conflicting SmithyIntegration names detected for '%s': %s and %s", + integration.name(), + integration.getClass().getCanonicalName(), + previous.getClass().getCanonicalName())); } + insertionOrder.put(integration.name(), insertionOrder.size()); + dependencyGraph.add(integration.name()); } - - // Offer satisfied dependencies. for (I integration : integrations) { - if (!forwardDependencies.containsKey(integration.name())) { - satisfied.offer(integration.name()); + dependencyGraph.addDependencies(integration.name(), integration.runAfter()); + for (String dependant : integration.runBefore()) { + dependencyGraph.addDependency(dependant, integration.name()); } } } - private void addIntegration(I integration) { - I previous = this.integrationLookup.put(integration.name(), integration); - insertionOrder.put(integration.name(), insertionOrder.size()); - if (previous != null) { - throw new IllegalArgumentException(String.format( - "Conflicting SmithyIntegration names detected for '%s': %s and %s", - integration.name(), - integration.getClass().getCanonicalName(), - previous.getClass().getCanonicalName())); + List sort() { + List result = new ArrayList<>(dependencyGraph.size()); + List sorted; + try { + sorted = dependencyGraph.toSortedList(this::compareIntegrations); + } catch (CycleException e) { + throw new IllegalArgumentException(e); } - } - - private List getValidatedDependencies(String descriptor, String what, List dependencies) { - if (dependencies.isEmpty()) { - return dependencies; - } else { - List filtered = new ArrayList<>(dependencies); - filtered.removeIf(value -> { - if (integrationLookup.containsKey(value)) { - return false; - } else { - LOGGER.warning(what + " is supposed to run " + descriptor + " an integration that could " - + "not be found, '" + value + "'"); - return true; - } - }); - return filtered; + for (String name : sorted) { + I integration = integrationLookup.get(name); + if (integration != null) { + result.add(integration); + } else { + logMissingIntegration(name); + } } + return result; } - private void addDependency(String what, String dependsOn) { - forwardDependencies.computeIfAbsent(what, n -> new LinkedHashSet<>()).add(dependsOn); - reverseDependencies.computeIfAbsent(dependsOn, n -> new LinkedHashSet<>()).add(what); + private int compareIntegrations(String left, String right) { + I leftIntegration = integrationLookup.get(left); + I rightIntegration = integrationLookup.get(right); + if (leftIntegration == null || rightIntegration == null) { + return 0; + } + // Priority order is used to sort first. + int byteResult = Byte.compare(rightIntegration.priority(), leftIntegration.priority()); + // If priority is a tie, then sort based on insertion order of integrations. + // This makes the order deterministic. + return byteResult == 0 + ? Integer.compare(insertionOrder.get(left), insertionOrder.get(right)) + : byteResult; } - List sort() { - List result = new ArrayList<>(); - - while (!satisfied.isEmpty()) { - String current = satisfied.poll(); - forwardDependencies.remove(current); - result.add(integrationLookup.get(current)); - - for (String dependent : reverseDependencies.getOrDefault(current, Collections.emptySet())) { - Set dependentDependencies = forwardDependencies.get(dependent); - dependentDependencies.remove(current); - if (dependentDependencies.isEmpty()) { - satisfied.offer(dependent); - } - } + private void logMissingIntegration(String name) { + StringBuilder message = new StringBuilder("Could not find SmithyIntegration named '"); + message.append(name).append('\''); + if (!dependencyGraph.getDirectDependants(name).isEmpty()) { + message.append(" that was supposed to run before integrations ["); + message.append(String.join(", ", dependencyGraph.getDirectDependants(name))); + message.append("]"); } - - if (!forwardDependencies.isEmpty()) { - throw new IllegalArgumentException("SmithyIntegration cycles detected among " - + forwardDependencies.keySet()); + if (!dependencyGraph.getDirectDependencies(name).isEmpty()) { + if (!dependencyGraph.getDirectDependants(name).isEmpty()) { + message.append(" and "); + } else { + message.append(" that "); + } + message.append("was supposed to run after integrations ["); + message.append(String.join(", ", dependencyGraph.getDirectDependencies(name))); + message.append("]"); } - - return result; + LOGGER.warning(message.toString()); } } diff --git a/smithy-codegen-core/src/test/java/software/amazon/smithy/codegen/core/IntegrationTopologicalSortTest.java b/smithy-codegen-core/src/test/java/software/amazon/smithy/codegen/core/IntegrationTopologicalSortTest.java index 2f73723ef4f..eb97efda205 100644 --- a/smithy-codegen-core/src/test/java/software/amazon/smithy/codegen/core/IntegrationTopologicalSortTest.java +++ b/smithy-codegen-core/src/test/java/software/amazon/smithy/codegen/core/IntegrationTopologicalSortTest.java @@ -6,6 +6,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import java.util.ArrayList; @@ -171,7 +172,7 @@ public void detectsCycles() { RuntimeException e = Assertions.assertThrows(IllegalArgumentException.class, () -> SmithyIntegration.sort(integrations)); - assertThat(e.getMessage(), equalTo("SmithyIntegration cycles detected among [b, d]")); + assertThat(e.getMessage(), containsString("[b, d]")); } @Test diff --git a/smithy-utils/src/main/java/software/amazon/smithy/utils/CycleException.java b/smithy-utils/src/main/java/software/amazon/smithy/utils/CycleException.java new file mode 100644 index 00000000000..b27e989d232 --- /dev/null +++ b/smithy-utils/src/main/java/software/amazon/smithy/utils/CycleException.java @@ -0,0 +1,80 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.utils; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Signals that one or more cycles have been detected when attempting to topologically + * sort shapes in a {@link DependencyGraph}. + */ +public class CycleException extends RuntimeException { + private final List sortedNodes; + private final Set cyclicNodes; + private final Class nodeType; + + /** + * Constructs a CycleException. + * + * @param sortedNodes A list of the nodes that were sorted successfully. + * @param cyclicNodes A set of nodes that could not be sorted due to being part of a cycle. + * @param The type of the node. + */ + public CycleException(List sortedNodes, Set cyclicNodes) { + super(String.format("Cycle(s) detected amongst: [%s]", + cyclicNodes.stream().map(Object::toString).collect(Collectors.joining(", ")))); + this.sortedNodes = ListUtils.copyOf(sortedNodes); + this.cyclicNodes = SetUtils.orderedCopyOf(cyclicNodes); + if (this.cyclicNodes.isEmpty()) { + throw new IllegalArgumentException("Cyclic nodes cannot be empty"); + } + this.nodeType = this.cyclicNodes.iterator().next().getClass(); + } + + /** + * Gets the set of nodes that are part of a cycle. + * + *

This contains all nodes that are a part of any cycles. To see a list of + * individual cycles, use {@link DependencyGraph#findCycles()}. + * + * @param expectedNodeType The expected type of the node, which will be checked to + * be compatible with the actual type. This is necessary because + * exceptions can't be generic. + * @return Returns a set of cyclic nodes. + * @param The type of the graph's nodes. + */ + @SuppressWarnings("unchecked") + public Set getCyclicNodes(Class expectedNodeType) { + if (expectedNodeType.isAssignableFrom(this.nodeType)) { + return (Set) cyclicNodes; + } + throw new IllegalArgumentException(String.format( + "Expected node type %s is not assignable from actual node type %s", + expectedNodeType.getName(), + this.nodeType.getName())); + } + + /** + * Gets the list of nodes that could be sorted. + * + * @param expectedNodeType The expected type of the node, which will be checked to + * be compatible with the actual type. This is necessary because + * exceptions can't be generic. + * @return Returns the sorted list of non-cyclic nodes. + * @param The type of the graph's nodes. + */ + @SuppressWarnings("unchecked") + public List getSortedNodes(Class expectedNodeType) { + if (expectedNodeType.isAssignableFrom(this.nodeType)) { + return (List) sortedNodes; + } + throw new IllegalArgumentException(String.format( + "Expected node type %s is not assignable from actual node type %s", + expectedNodeType.getName(), + this.nodeType.getName())); + } +} diff --git a/smithy-utils/src/main/java/software/amazon/smithy/utils/DependencyGraph.java b/smithy-utils/src/main/java/software/amazon/smithy/utils/DependencyGraph.java new file mode 100644 index 00000000000..c76deb4ebe1 --- /dev/null +++ b/smithy-utils/src/main/java/software/amazon/smithy/utils/DependencyGraph.java @@ -0,0 +1,404 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.utils; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.Queue; +import java.util.Set; + +/** + * A basic dependency graph. + * + *

Iteration and ordering in collection methods are based on insertion order. + * A topographically sorted view of the graph is provided by {@link #toSortedList()}. + * + * @param The data type stored in the graph. + */ +public class DependencyGraph implements Collection, Iterable { + + private final Map> forwardDependencies; + private final Map> reverseDependencies; + + /** + * Constructs a new, empty dependency graph with the specified initial capacity. + * + * @param initialCapacity the initial capacity of the dependency graph. + */ + public DependencyGraph(int initialCapacity) { + this.forwardDependencies = new LinkedHashMap<>(initialCapacity); + this.reverseDependencies = new LinkedHashMap<>(initialCapacity); + } + + /** + * Constructs a new dependency graph with the same elements as the specified + * collection. The dependency graph is created with an initial capacity equal + * to the size of the collection. + * + * @param c the collection whose elements are to be added to this graph. + */ + public DependencyGraph(Collection c) { + this(c.size()); + this.addAll(c); + } + + /** + * Constructs a new dependency graph with the default initial capacity (16). + */ + public DependencyGraph() { + this(16); + } + + @Override + public boolean add(T t) { + if (forwardDependencies.containsKey(t)) { + return false; + } + forwardDependencies.put(t, new LinkedHashSet<>()); + reverseDependencies.put(t, new LinkedHashSet<>()); + return true; + } + + /** + * Adds a dependency between two nodes. + * + *

If either node is not already present in the graph, it is added. + * + * @param what The node to add a dependency to. + * @param dependsOn The dependency to add. + */ + public void addDependency(T what, T dependsOn) { + if (!forwardDependencies.containsKey(what)) { + add(what); + } + if (!forwardDependencies.containsKey(dependsOn)) { + add(dependsOn); + } + forwardDependencies.get(what).add(dependsOn); + reverseDependencies.get(dependsOn).add(what); + } + + /** + * Adds a set of dependencies to a node. + * + *

If any node is not already present in the graph, it is added. + * + * @param what The node to add dependencies to. + * @param dependsOn The dependencies to add. + */ + public void addDependencies(T what, Collection dependsOn) { + if (!forwardDependencies.containsKey(what)) { + add(what); + } + Set dependencies = forwardDependencies.get(what); + for (T dependency : dependsOn) { + if (!forwardDependencies.containsKey(dependency)) { + add(dependency); + } + dependencies.add(dependency); + reverseDependencies.get(dependency).add(what); + } + } + + @Override + public boolean remove(Object o) { + if (!forwardDependencies.containsKey(o)) { + return false; + } + for (T dependant : reverseDependencies.get(o)) { + forwardDependencies.get(dependant).remove(o); + } + forwardDependencies.remove(o); + reverseDependencies.remove(o); + return true; + } + + /** + * Removes a dependency from a node. + * + * @param what The node to remove a dependency from. + * @param dependsOn The dependency to remove. + */ + public void removeDependency(T what, T dependsOn) { + forwardDependencies.get(what).remove(dependsOn); + reverseDependencies.get(dependsOn).remove(what); + } + + /** + * Removes a set of dependencies from a node. + * + * @param what The node to remove dependencies from. + * @param dependsOn The dependencies to remove. + */ + public void removeDependencies(T what, Collection dependsOn) { + forwardDependencies.get(what).removeAll(dependsOn); + for (T dependency : dependsOn) { + reverseDependencies.get(dependency).remove(what); + } + } + + /** + * @return Returns a set of nodes that have no remaining dependencies. + */ + public Set getIndependentNodes() { + Set result = new LinkedHashSet<>(); + for (Map.Entry> node : forwardDependencies.entrySet()) { + if (node.getValue().isEmpty()) { + result.add(node.getKey()); + } + } + return result; + } + + /** + * Gets all the direct dependencies of a given node. + * + * @param node The node whose dependencies should be fetched. + * @return A set of dependencies. + */ + public Set getDirectDependencies(T node) { + Set result = forwardDependencies.get(node); + if (result == null) { + return null; + } + return SetUtils.copyOf(result); + } + + /** + * Gets all the nodes that depend on a given node. + * + * @param node The node whose dependants should be fetched. + * @return A set of dependants. + */ + public Set getDirectDependants(T node) { + Set result = reverseDependencies.get(node); + if (result == null) { + return null; + } + return SetUtils.copyOf(result); + } + + /** + * Finds all strongly-connected components of the graph. + * + *

Cycles returned are not *elementary* cycles. That is, if two or more + * cycles share any nodes, they will be returned as a single cycle. For + * example, take the following graph: + * + *

+     *     A -> B -> C -> A
+     *          B -> D -> B
+     * 
+ * + *

This graph has one set of strongly-connected components + * ({@code {B, C, A, D}}) made up of two elementary cycles + * ({@code {B, C, A}} and {@code D, B}). This method will return the + * set of strongly-connected components, {@code {B, C, A, D}}. + * + * @return A list of all strongly-connected components in the graph. + */ + public List> findCycles() { + List> cycles = new ArrayList<>(); + Map indexes = new HashMap<>(size()); + Map lowLinks = new HashMap<>(size()); + Deque stack = new ArrayDeque<>(size()); + Set onStack = new HashSet<>(size()); + int index = 0; + + for (T node : reverseDependencies.keySet()) { + if (!indexes.containsKey(node)) { + index = strongConnect(node, indexes, lowLinks, stack, onStack, cycles, index); + } + } + return cycles; + } + + private int strongConnect( + T current, + Map indexes, + Map lowLinks, + Deque stack, + Set onStack, + List> cycles, + int index + ) { + indexes.put(current, index); + lowLinks.put(current, index); + index++; + stack.push(current); + onStack.add(current); + + for (T dependent : reverseDependencies.get(current)) { + if (!indexes.containsKey(dependent)) { + index = strongConnect(dependent, indexes, lowLinks, stack, onStack, cycles, index); + lowLinks.put(current, Math.min(lowLinks.get(current), lowLinks.get(dependent))); + } else if (onStack.contains(dependent)) { + lowLinks.put(current, Math.min(lowLinks.get(current), indexes.get(dependent))); + } + } + + if (lowLinks.get(current).equals(indexes.get(current))) { + List cycle = new ArrayList<>(); + T node; + do { + node = stack.pop(); + onStack.remove(node); + cycle.add(node); + } while (!node.equals(current)); + if (cycle.size() > 1) { + cycles.add(cycle); + } + } + return index; + } + + @Override + public int size() { + return forwardDependencies.size(); + } + + @Override + public boolean isEmpty() { + return forwardDependencies.isEmpty(); + } + + @Override + public boolean contains(Object o) { + return forwardDependencies.containsKey(o); + } + + @Override + public Iterator iterator() { + return forwardDependencies.keySet().iterator(); + } + + @Override + public Object[] toArray() { + return forwardDependencies.keySet().toArray(); + } + + @Override + public T1[] toArray(T1[] a) { + return forwardDependencies.keySet().toArray(a); + } + + /** + * Gets a topographically sorted view of the graph. + * + * @return Returns a topographically sorted list view of the graph. + */ + public List toSortedList() { + return toSortedList(new ArrayDeque<>()); + } + + /** + * Gets a topographically sorted view of the graph where independent + * nodes are evaluated in the order given by the comparator. + * + * @param comparator A comparator used to sort independent nodes. + * @return Returns a topographically sorted list view of the graph. + */ + public List toSortedList(Comparator comparator) { + return toSortedList(new PriorityQueue<>(comparator)); + } + + private List toSortedList(Queue satisfied) { + // Create a mapping of dependency counts so we don't have to modify the graph. + Map inDegree = new HashMap<>(forwardDependencies.size()); + for (Map.Entry> entry : forwardDependencies.entrySet()) { + int degree = entry.getValue().size(); + inDegree.put(entry.getKey(), degree); + + // If the node has no dependencies, go ahead and add it to the queue. + if (entry.getValue().isEmpty()) { + satisfied.offer(entry.getKey()); + } + } + + List result = new ArrayList<>(forwardDependencies.size()); + + // Process nodes in priority order. + while (!satisfied.isEmpty()) { + T node = satisfied.poll(); + result.add(node); + + // For each dependent node, decrease its dependency count by one. + for (T dependent : reverseDependencies.get(node)) { + int newCount = inDegree.get(dependent) - 1; + inDegree.put(dependent, newCount); + + // If all dependencies are satisfied, add the dependent to the queue. + if (newCount == 0) { + satisfied.add(dependent); + } + } + } + + // Check for cycles. + if (result.size() != reverseDependencies.size()) { + Set remaining = new LinkedHashSet<>(reverseDependencies.size() - result.size()); + for (T node : reverseDependencies.keySet()) { + if (!result.contains(node)) { + remaining.add(node); + } + } + throw new CycleException(result, remaining); + } + + return result; + } + + @Override + public boolean containsAll(Collection c) { + return forwardDependencies.keySet().containsAll(c); + } + + @Override + public boolean addAll(Collection c) { + boolean changed = false; + for (T element : c) { + changed = add(element) || changed; + } + return changed; + } + + @Override + public boolean removeAll(Collection c) { + boolean changed = false; + for (Object element : c) { + changed = remove(element) || changed; + } + return changed; + } + + @Override + public boolean retainAll(Collection c) { + List toRemove = new ArrayList<>(); + for (T element : forwardDependencies.keySet()) { + if (!c.contains(element)) { + toRemove.add(element); + } + } + return this.removeAll(toRemove); + } + + @Override + public void clear() { + forwardDependencies.clear(); + reverseDependencies.clear(); + } + +} diff --git a/smithy-utils/src/test/java/software/amazon/smithy/utils/CycleExceptionTest.java b/smithy-utils/src/test/java/software/amazon/smithy/utils/CycleExceptionTest.java new file mode 100644 index 00000000000..facadcc01a6 --- /dev/null +++ b/smithy-utils/src/test/java/software/amazon/smithy/utils/CycleExceptionTest.java @@ -0,0 +1,33 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.utils; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.Collections; +import java.util.List; +import java.util.Set; +import org.junit.jupiter.api.Test; + +public class CycleExceptionTest { + @Test + public void requiresCycles() { + assertThrows(IllegalArgumentException.class, + () -> new CycleException(Collections.emptyList(), Collections.emptySet())); + } + + @Test + public void checksNodeTypes() { + CycleException cycleException = new CycleException(ListUtils.of("foo"), SetUtils.of("bar")); + List sortedStrings = cycleException.getSortedNodes(String.class); + Set cyclicStrings = cycleException.getCyclicNodes(String.class); + + List sortedObjects = cycleException.getSortedNodes(Object.class); + Set cyclicObjects = cycleException.getCyclicNodes(Object.class); + + assertThrows(IllegalArgumentException.class, () -> cycleException.getSortedNodes(Integer.class)); + assertThrows(IllegalArgumentException.class, () -> cycleException.getCyclicNodes(Integer.class)); + } +} diff --git a/smithy-utils/src/test/java/software/amazon/smithy/utils/DependencyGraphTest.java b/smithy-utils/src/test/java/software/amazon/smithy/utils/DependencyGraphTest.java new file mode 100644 index 00000000000..49d6a7035c3 --- /dev/null +++ b/smithy-utils/src/test/java/software/amazon/smithy/utils/DependencyGraphTest.java @@ -0,0 +1,276 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.utils; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class DependencyGraphTest { + @Test + public void addsNodes() { + DependencyGraph graph = new DependencyGraph<>(); + assertTrue(graph.isEmpty()); + + assertTrue(graph.add("foo")); + assertFalse(graph.isEmpty()); + assertEquals(1, graph.size()); + assertThat(graph, contains("foo")); + + assertFalse(graph.add("foo")); + assertEquals(1, graph.size()); + assertThat(graph, contains("foo")); + + List newNodes = ListUtils.of("foo", "bar", "baz"); + assertTrue(graph.addAll(newNodes)); + assertFalse(graph.isEmpty()); + assertEquals(3, graph.size()); + assertThat(graph, contains("foo", "bar", "baz")); + + assertFalse(graph.addAll(newNodes)); + assertEquals(3, graph.size()); + assertThat(graph, contains("foo", "bar", "baz")); + } + + @Test + public void constructsFromCollection() { + List nodes = ListUtils.of("foo", "bar", "baz"); + DependencyGraph graph = new DependencyGraph<>(nodes); + assertFalse(graph.isEmpty()); + assertEquals(3, graph.size()); + assertThat(graph, contains("foo", "bar", "baz")); + } + + @Test + public void createsArrays() { + List nodes = ListUtils.of("foo", "bar", "baz"); + DependencyGraph graph = new DependencyGraph<>(nodes); + String[] expected = new String[] {"foo", "bar", "baz"}; + assertArrayEquals(expected, graph.toArray()); + assertArrayEquals(expected, graph.toArray(new String[0])); + } + + @Test + public void iteratesAllElements() { + List nodes = ListUtils.of("foo", "bar", "baz"); + DependencyGraph graph = new DependencyGraph<>(nodes); + List iterated = new ArrayList<>(); + for (String node : graph) { + iterated.add(node); + } + assertEquals(nodes, iterated); + } + + @Test + public void addsDependency() { + List nodes = ListUtils.of("foo", "bar", "baz"); + DependencyGraph graph = new DependencyGraph<>(nodes); + + graph.addDependency("foo", "bar"); + assertThat(graph.getDirectDependencies("foo"), contains("bar")); + assertThat(graph.getDirectDependants("bar"), contains("foo")); + } + + @Test + public void addsDependencies() { + List nodes = ListUtils.of("foo", "bar", "baz"); + DependencyGraph graph = new DependencyGraph<>(nodes); + + graph.addDependencies("foo", ListUtils.of("bar", "baz")); + assertThat(graph.getDirectDependencies("foo"), contains("bar", "baz")); + assertThat(graph.getDirectDependants("bar"), contains("foo")); + assertThat(graph.getDirectDependants("baz"), contains("foo")); + } + + @Test + public void addDependenciesAddsMissingNodes() { + DependencyGraph graph = new DependencyGraph<>(); + graph.addDependency("spam", "eggs"); + assertThat(graph, contains("spam", "eggs")); + + graph = new DependencyGraph<>(); + graph.addDependencies("foo", ListUtils.of("bar", "baz")); + assertThat(graph, contains("foo", "bar", "baz")); + } + + @Test + public void removesNodes() { + DependencyGraph graph = new DependencyGraph<>(); + graph.add("foo"); + assertThat(graph, contains("foo")); + + assertTrue(graph.remove("foo")); + assertTrue(graph.isEmpty()); + assertFalse(graph.remove("foo")); + + graph.addAll(ListUtils.of("foo", "bar", "baz")); + assertThat(graph, contains("foo", "bar", "baz")); + assertTrue(graph.removeAll(ListUtils.of("foo", "bar", "baz"))); + assertTrue(graph.isEmpty()); + assertFalse(graph.removeAll(ListUtils.of("foo", "bar", "baz"))); + + graph.addAll(ListUtils.of("foo", "bar", "baz")); + assertThat(graph, contains("foo", "bar", "baz")); + graph.clear(); + assertTrue(graph.isEmpty()); + } + + @Test + public void removedNodesAreRemovedFromDependencies() { + DependencyGraph graph = new DependencyGraph<>(); + graph.addDependency("foo", "bar"); + assertThat(graph, contains("foo", "bar")); + assertThat(graph.getDirectDependencies("foo"), contains("bar")); + assertThat(graph.getDirectDependants("bar"), contains("foo")); + + graph.remove("bar"); + assertThat(graph, contains("foo")); + assertTrue(graph.getDirectDependencies("foo").isEmpty()); + assertNull(graph.getDirectDependants("bar")); + } + + @Test + public void removesDependency() { + DependencyGraph graph = new DependencyGraph<>(); + graph.addDependency("spam", "eggs"); + assertThat(graph, contains("spam", "eggs")); + assertThat(graph.getDirectDependencies("spam"), contains("eggs")); + + graph.removeDependency("spam", "eggs"); + assertThat(graph, contains("spam", "eggs")); + assertTrue(graph.getDirectDependencies("spam").isEmpty()); + } + + @Test + public void removesDependencies() { + DependencyGraph graph = new DependencyGraph<>(); + graph.addDependencies("foo", ListUtils.of("bar", "baz")); + assertThat(graph, contains("foo", "bar", "baz")); + assertThat(graph.getDirectDependencies("foo"), contains("bar", "baz")); + + graph.removeDependencies("foo", ListUtils.of("bar", "baz")); + assertThat(graph, contains("foo", "bar", "baz")); + assertTrue(graph.getDirectDependencies("foo").isEmpty()); + } + + @Test + public void retainsNodes() { + List nodes = ListUtils.of("foo", "bar", "baz"); + DependencyGraph graph = new DependencyGraph<>(nodes); + assertThat(graph, contains("foo", "bar", "baz")); + + assertTrue(graph.retainAll(ListUtils.of("foo", "baz"))); + assertThat(graph, contains("foo", "baz")); + assertFalse(graph.retainAll(ListUtils.of("foo", "baz"))); + } + + @Test + public void returnsNullEdgesForMissingNodes() { + DependencyGraph graph = new DependencyGraph<>(); + assertTrue(graph.isEmpty()); + assertNull(graph.getDirectDependencies("foo")); + assertNull(graph.getDirectDependants("bar")); + } + + @Test + public void getsIndependentNodes() { + DependencyGraph graph = new DependencyGraph<>(); + assertTrue(graph.getIndependentNodes().isEmpty()); + + graph.addDependency("spam", "eggs"); + graph.add("foo"); + assertThat(graph, contains("spam", "eggs", "foo")); + + assertThat(graph.getIndependentNodes(), contains("eggs", "foo")); + } + + @Test + public void topologicallySorts() { + List nodes = ListUtils.of("foo", "bar", "baz"); + DependencyGraph graph = new DependencyGraph<>(nodes); + graph.addDependency("bar", "baz"); + List actual = graph.toSortedList(); + List expected = ListUtils.of("foo", "baz", "bar"); + assertEquals(expected, actual); + } + + @Test + public void sortsComplexGraph() { + DependencyGraph graph = new DependencyGraph<>(); + graph.addDependencies("a", ListUtils.of("b", "c")); + graph.addDependencies("b", ListUtils.of("c", "d")); + graph.addDependency("c", "d"); + List actual = graph.toSortedList(); + List expected = ListUtils.of("d", "c", "b", "a"); + assertEquals(expected, actual); + } + + @Test + public void topologicallySortsWithCustomComparator() { + List nodes = ListUtils.of("foo", "bar", "baz"); + DependencyGraph graph = new DependencyGraph<>(nodes); + graph.addDependency("bar", "baz"); + List actual = graph.toSortedList(String.CASE_INSENSITIVE_ORDER); + List expected = ListUtils.of("baz", "bar", "foo"); + assertEquals(expected, actual); + } + + @Test() + public void sortedListThrowsErrorOnCycle() { + DependencyGraph graph = new DependencyGraph<>(); + graph.add("foo"); + graph.addDependency("bar", "foo"); + graph.addDependency("spam", "eggs"); + graph.addDependency("eggs", "spam"); + + CycleException exception = assertThrows(CycleException.class, graph::toSortedList); + assertThat(exception.getSortedNodes(String.class), contains("foo", "bar")); + assertThat(exception.getCyclicNodes(String.class), containsInAnyOrder("spam", "eggs")); + } + + @Test + public void detectsSimpleCycle() { + DependencyGraph graph = new DependencyGraph<>(); + graph.addDependency("spam", "eggs"); + graph.addDependency("eggs", "spam"); + assertThat(graph.findCycles(), contains(ListUtils.of("eggs", "spam"))); + } + + @Test + public void detectsComplexCycle() { + DependencyGraph graph = new DependencyGraph<>(); + graph.addDependency("a", "b"); + graph.addDependency("b", "c"); + graph.addDependency("c", "d"); + graph.addDependency("d", "b"); + assertThat(graph.findCycles(), contains(ListUtils.of("c", "d", "b"))); + } + + @Test + public void detectsMultipleCycles() { + DependencyGraph graph = new DependencyGraph<>(); + graph.addDependency("spam", "eggs"); + graph.addDependency("eggs", "spam"); + graph.addDependency("a", "b"); + graph.addDependency("b", "c"); + graph.addDependency("c", "d"); + graph.addDependency("d", "b"); + assertThat(graph.findCycles(), + containsInAnyOrder( + ListUtils.of("eggs", "spam"), + ListUtils.of("c", "d", "b"))); + + } +}