diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java new file mode 100644 index 000000000..d92d0ccd6 --- /dev/null +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java @@ -0,0 +1,31 @@ +package org.datavec.api.transform.ops; + +import com.tngtech.archunit.core.importer.ImportOption; +import com.tngtech.archunit.junit.AnalyzeClasses; +import com.tngtech.archunit.junit.ArchTest; +import com.tngtech.archunit.junit.ArchUnitRunner; +import com.tngtech.archunit.lang.ArchRule; +import org.junit.runner.RunWith; +import org.nd4j.common.tests.BaseND4JTest; + +import java.io.Serializable; + +import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes; + +/** + * Created by dariuszzbyrad on 7/31/2020. + */ +@RunWith(ArchUnitRunner.class) +@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = {ImportOption.DoNotIncludeTests.class}) +public class AggregableMultiOpArchTest extends BaseND4JTest { + + @ArchTest + public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes() + .that().resideInAPackage("org.datavec.api.transform.ops") + .and().doNotHaveSimpleName("AggregatorImpls") + .and().doNotHaveSimpleName("IAggregableReduceOp") + .and().doNotHaveSimpleName("StringAggregatorImpls") + .and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1") + .should().implement(Serializable.class) + .because("All aggregate ops must be serializable."); +} \ No newline at end of file diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java index 3e032f329..92d7d8dbc 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java @@ -68,91 +68,4 @@ public class AggregableMultiOpTest extends BaseND4JTest { assertTrue(combinedRes.get(1).toDouble() == 90D); assertTrue(combinedRes.get(0).toInt() == 1); } - - @Test - public void testAllAggregateOpsAreSerializable() throws Exception { - Set allTypes = new HashSet<>(); - allTypes.add("org.datavec.api.transform.ops.LongWritableOp"); - allTypes.add("org.datavec.api.transform.ops.IntWritableOp"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableMean"); - allTypes.add("org.datavec.api.transform.ops.StringAggregatorImpls$AggregableStringReduce"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableRange"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImplsTest"); - allTypes.add("org.datavec.api.transform.ops.DispatchWithConditionOp"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableVariance"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls"); - allTypes.add("org.datavec.api.transform.ops.FloatWritableOp"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableProd"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableLast"); - allTypes.add("org.datavec.api.transform.ops.StringAggregatorImpls$AggregableStringPrepend"); - allTypes.add("org.datavec.api.transform.ops.ByteWritableOp"); - allTypes.add("org.datavec.api.transform.ops.AggregableMultiOpTest"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableStdDev"); - allTypes.add("org.datavec.api.transform.ops.StringAggregatorImpls$1"); - allTypes.add("org.datavec.api.transform.ops.DispatchOp"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableMin"); - allTypes.add("org.datavec.api.transform.ops.StringAggregatorImpls$AggregableStringAppend"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableCount"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableSum"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregablePopulationVariance"); - allTypes.add("org.datavec.api.transform.ops.AggregableCheckingOp"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableMax"); - allTypes.add("org.datavec.api.transform.ops.AggregableMultiOp"); - allTypes.add("org.datavec.api.transform.ops.IAggregableReduceOp"); - allTypes.add("org.datavec.api.transform.ops.DispatchOpTest"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableCountUnique"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableUncorrectedStdDev"); - allTypes.add("org.datavec.api.transform.ops.StringWritableOp"); - allTypes.add("org.datavec.api.transform.ops.StringAggregatorImpls"); - allTypes.add("org.datavec.api.transform.ops.DoubleWritableOp"); - allTypes.add("org.datavec.api.transform.ops.AggregatorImpls$AggregableFirst"); - - Set ops = new HashSet<>(); - - for (String type : allTypes) { - if (type.startsWith("org.datavec.api.transform.ops")) { - if (type.endsWith("Op")) { - ops.add(type); - } - - if (type.contains("Aggregable") && !type.endsWith("Test")) { - ops.add(type); - } - } - } - - for (String op : ops) { - Class cls = Class.forName(op); - assertTrue(op + " should implement Serializable", implementsSerializable(cls)); - } - } - - private boolean implementsSerializable(Class cls) { - if (cls == null) { - return false; - } - if (cls == Serializable.class) { - return true; - } - - Class[] interfaces = cls.getInterfaces(); - Set> parents = new HashSet<>(); - parents.add(cls.getSuperclass()); - - for (Class anInterface : interfaces) { - Collections.addAll(parents, anInterface.getInterfaces()); - - if (anInterface.equals(Serializable.class)) { - return true; - } - } - - for (Class parent : parents) { - if (implementsSerializable(parent)) { - return true; - } - } - - return false; - } } diff --git a/datavec/pom.xml b/datavec/pom.xml index 8e403ea1e..c6e342ace 100644 --- a/datavec/pom.xml +++ b/datavec/pom.xml @@ -75,6 +75,12 @@ ${junit.version} test + + com.tngtech.archunit + archunit-junit4 + ${archunit.version} + test + org.projectlombok lombok diff --git a/pom.xml b/pom.xml index 9c9d771fb..f6078f993 100644 --- a/pom.xml +++ b/pom.xml @@ -333,6 +333,7 @@ 2.0.29 1.7.21 4.12 + 0.14.1 1.2.3 2.10.1 2.10.3