From 83f88a1b0d5c16607680f5fd021533291342c9da Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sun, 26 Apr 2020 12:27:11 +1000 Subject: [PATCH] [WIP] Add test: "all ops in namespace" scanner/check (#413) * #8879 'All ops in namespace' scanner first steps Signed-off-by: Alex Black * Clean up and polish op mapping check utility Signed-off-by: Alex Black --- .../java/org/nd4j/autodiff/TestOpMapping.java | 219 +++++++++++++++++- 1 file changed, 215 insertions(+), 4 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java index e88f195c0..adac58697 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,8 @@ package org.nd4j.autodiff; +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.FilenameUtils; import org.junit.Ignore; import org.junit.Test; import org.nd4j.autodiff.functions.DifferentialFunction; @@ -23,20 +26,68 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.converters.ImportClassMapping; import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.NoOp; +import org.nd4j.linalg.api.ops.compat.CompatSparseToDense; +import org.nd4j.linalg.api.ops.compat.CompatStringSplit; +import org.nd4j.linalg.api.ops.custom.*; +import org.nd4j.linalg.api.ops.impl.broadcast.*; +import org.nd4j.linalg.api.ops.impl.broadcast.bool.*; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; +import org.nd4j.linalg.api.ops.impl.grid.FreeGridOp; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; +import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DTF; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DTF; +import org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp; +import org.nd4j.linalg.api.ops.impl.meta.PostulateMetaOp; +import org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp; +import org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp; +import org.nd4j.linalg.api.ops.impl.nlp.CbowRound; +import org.nd4j.linalg.api.ops.impl.nlp.SkipGramRound; +import org.nd4j.linalg.api.ops.impl.reduce.HashCode; +import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarSetValue; +import org.nd4j.linalg.api.ops.impl.shape.ApplyGradientDescent; +import org.nd4j.linalg.api.ops.impl.shape.Create; +import org.nd4j.linalg.api.ops.impl.shape.ParallelStack; +import org.nd4j.linalg.api.ops.impl.transforms.any.Assign; +import org.nd4j.linalg.api.ops.impl.transforms.custom.ParallelConcat; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp; +import org.nd4j.linalg.api.ops.impl.updaters.*; +import org.nd4j.linalg.api.ops.persistence.RestoreV2; +import org.nd4j.linalg.api.ops.persistence.SaveV2; +import org.nd4j.linalg.api.ops.util.PrintAffinity; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.resources.Resources; import org.reflections.Reflections; +import java.io.File; import java.lang.reflect.Modifier; +import java.nio.charset.StandardCharsets; import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; public class TestOpMapping extends BaseNd4jTest { + Set> subTypes; + public TestOpMapping(Nd4jBackend b){ super(b); + + Reflections reflections = new Reflections("org.nd4j"); + subTypes = reflections.getSubTypesOf(DifferentialFunction.class); } @Override @@ -46,14 +97,13 @@ public class TestOpMapping extends BaseNd4jTest { @Override public long getTimeoutMilliseconds() { - return 60000L; + return 90000L; } + + @Test public void testOpMappingCoverage() throws Exception { - Reflections reflections = new Reflections("org.nd4j"); - Set> subTypes = reflections.getSubTypesOf(DifferentialFunction.class); - Map opNameMapping = ImportClassMapping.getOpNameMapping(); Map tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions(); Map onnxOpNameMapping = ImportClassMapping.getOnnxOpMappingFunctions(); @@ -102,6 +152,167 @@ public class TestOpMapping extends BaseNd4jTest { } } + @Test + public void testOpsInNamespace() throws Exception { + //Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't + // want to add to a namespace for some reason) + //Note that we ignore "*Bp", "*Gradient", "*Derivative" etc ops + + String path = FilenameUtils.concat(new File("").getAbsolutePath(), "../nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops"); + path = FilenameUtils.normalize(path); + System.out.println(path); + File dir = new File(path); + Collection c = FileUtils.listFiles(dir, new String[]{"java"}, true); + + String strPattern = " org.nd4j.linalg.api.ops(\\.(\\w)+)+"; + Pattern pattern = Pattern.compile(strPattern); + + + Set seenClasses = new HashSet<>(); + for(File f1 : c){ + List lines = FileUtils.readLines(f1, StandardCharsets.UTF_8); + for(String l : lines){ + Matcher matcher = pattern.matcher(l); + while(matcher.find()){ + int s = matcher.start(); + int e = matcher.end(); + + String str = l.substring(s+1,e); //+1 because pattern starts with space + seenClasses.add(str); + } + } + } + + int countNotSeen = 0; + int countSeen = 0; + List notSeen = new ArrayList<>(); + for(Class cl : subTypes){ + String s = cl.getName(); + + //Backprop/gradient ops should not be in namespaces + if(s.endsWith("Bp") || s.endsWith("BpOp") || s.endsWith("Gradient") || s.endsWith("Derivative") || s.endsWith("Grad")) + continue; + + if(Modifier.isAbstract(cl.getModifiers()) || Modifier.isInterface(cl.getModifiers())) //Skip interfaces and abstract methods + continue; + + if(excludedFromNamespaces.contains(cl)) //Explicitly excluded - don't want in namespaces + continue; + + if(!seenClasses.contains(s)){ +// System.out.println("NOT SEEN: " + s); + notSeen.add(s); + countNotSeen++; + } else { + countSeen++; + } + } + + Collections.sort(notSeen); + System.out.println(String.join("\n", notSeen)); + + System.out.println("Not seen ops: " + countNotSeen); + System.out.println("Seen ops: " + countSeen); + System.out.println("Namespace scan count ops: " + seenClasses.size()); + } + + //Set of classes that we explicitly don't want in a namespace for some reason + private static final Set> excludedFromNamespaces = new HashSet<>(); + static { + Set> s = excludedFromNamespaces; + + //Updaters - used via TrainingConfig, not namespaces + s.add(AdaDeltaUpdater.class); + s.add(AdaGradUpdater.class); + s.add(AdaMaxUpdater.class); + s.add(AdamUpdater.class); + s.add(AmsGradUpdater.class); + s.add(NadamUpdater.class); + s.add(NesterovsUpdater.class); + s.add(RmsPropUpdater.class); + s.add(SgdUpdater.class); + + //Legacy broadcast ops + s.add(BroadcastAddOp.class); + s.add(BroadcastAMax.class); + s.add(BroadcastAMin.class); + s.add(BroadcastCopyOp.class); + s.add(BroadcastDivOp.class); + s.add(BroadcastGradientArgs.class); + s.add(BroadcastMax.class); + s.add(BroadcastMin.class); + s.add(BroadcastMulOp.class); + s.add(BroadcastRDivOp.class); + s.add(BroadcastRSubOp.class); + s.add(BroadcastSubOp.class); + s.add(BroadcastTo.class); + s.add(BroadcastEqualTo.class); + s.add(BroadcastGreaterThan.class); + s.add(BroadcastGreaterThanOrEqual.class); + s.add(BroadcastLessThan.class); + s.add(BroadcastLessThanOrEqual.class); + s.add(BroadcastNotEqual.class); + + //Redundant operations + s.add(ArgMax.class); //IMax already in namespace + s.add(ArgMin.class); //IMin already in namespace + + //Various utility methods, used internally + s.add(DynamicCustomOp.class); + s.add(ExternalErrorsFunction.class); + s.add(GradientBackwardsMarker.class); + s.add(KnnMinDistance.class); + s.add(BinaryRelativeError.class); + s.add(SpTreeCell.class); + s.add(BarnesHutGains.class); + s.add(BinaryMinimalRelativeError.class); + s.add(SkipGramRound.class); + s.add(BarnesHutSymmetrize.class); + s.add(BarnesEdgeForces.class); + s.add(CbowRound.class); + + //For TF compatibility only + s.add(NoOp.class); + s.add(RestoreV2.class); + s.add(ParallelConcat.class); + s.add(ParallelStack.class); + s.add(DeConv2DTF.class); + s.add(DeConv3DTF.class); + s.add(CompatSparseToDense.class); + s.add(CompatStringSplit.class); + s.add(ApplyGradientDescent.class); + s.add(RealDivOp.class); + s.add(SaveV2.class); + + //Control ops, used internally as part of loops etc + s.add(Enter.class); + s.add(Exit.class); + s.add(NextIteration.class); + s.add(LoopCond.class); + s.add(Merge.class); + s.add(Switch.class); + + //MetaOps, grid ops etc not part of public API + s.add(InvertedPredicateMetaOp.class); + s.add(PostulateMetaOp.class); + s.add(PredicateMetaOp.class); + s.add(ReduceMetaOp.class); + s.add(FreeGridOp.class); + + //Others that don't relaly make sense as a namespace method + s.add(CopyOp.class); + s.add(org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set.class); + s.add(PowPairwise.class); //We have custom op Pow already used for this + s.add(Create.class); //Already have zeros, ones, etc for this + s.add(HashCode.class); + s.add(ScalarSetValue.class); + s.add(PrintVariable.class); + s.add(PrintAffinity.class); + s.add(Assign.class); + + + + } @Test @Ignore public void generateOpClassList() throws Exception{