From f29f19e9e9858287bdf73472a95da84181cc2617 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Fri, 19 Jul 2019 18:41:05 +0300 Subject: [PATCH] [WIP] Some nd4s tweaks (#68) * Executioner fallback Signed-off-by: Alexander Stoyakin * Tests for executioner Signed-off-by: Alexander Stoyakin --- .../nd4s/ops/FunctionalOpExecutioner.scala | 103 ++++++++++++------ .../org/nd4s/NDArrayCollectionAPITest.scala | 15 +++ 2 files changed, 85 insertions(+), 33 deletions(-) diff --git a/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala b/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala index 5ead2e43d..826264b8f 100644 --- a/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala +++ b/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala @@ -19,6 +19,7 @@ import java.util.{ List, Map, Properties } import org.bytedeco.javacpp.Pointer import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType, Utf8Buffer } +import org.nd4j.linalg.api.environment.Nd4jEnvironment import org.nd4j.linalg.api.ndarray.{ INDArray, INDArrayStatistics } import org.nd4j.linalg.api.ops.aggregates.{ Aggregate, Batch } import org.nd4j.linalg.api.ops._ @@ -30,19 +31,33 @@ import org.nd4j.linalg.api.shape.{ LongShapeDescriptor, TadPack } import org.nd4j.linalg.cache.TADManager import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.profiler.ProfilerConfig +import org.slf4j.{ Logger, LoggerFactory } object FunctionalOpExecutioner { def apply: FunctionalOpExecutioner = new FunctionalOpExecutioner() } class FunctionalOpExecutioner extends OpExecutioner { - def isVerbose: Boolean = ??? + + def log: Logger = LoggerFactory.getLogger(FunctionalOpExecutioner.getClass) + + private[this] var verboseEnabled: Boolean = false + + def isVerbose: Boolean = verboseEnabled + + def enableVerboseMode(reallyEnable: Boolean): Unit = + verboseEnabled = reallyEnable /** * This method returns true if debug mode is enabled, false otherwise * * @return */ - def isDebug: Boolean = ??? + private[this] var debugEnabled: Boolean = false + + def isDebug: Boolean = debugEnabled + + def enableDebugMode(reallyEnable: Boolean): Unit = + debugEnabled = reallyEnable /** * This method returns type for this executioner instance @@ -112,7 +127,8 @@ class FunctionalOpExecutioner extends OpExecutioner { * * @param op the operation to execute */ - def execAndReturn(op: TransformOp): TransformOp = ??? + def execAndReturn(op: TransformOp): TransformOp = + Nd4j.getExecutioner.execAndReturn(op) /** * Execute and return the result from an accumulation @@ -120,28 +136,33 @@ class FunctionalOpExecutioner extends OpExecutioner { * @param op the operation to execute * @return the accumulated result */ - def execAndReturn(op: ReduceOp): ReduceOp = ??? + def execAndReturn(op: ReduceOp): ReduceOp = + Nd4j.getExecutioner.execAndReturn(op) - def execAndReturn(op: Variance): Variance = ??? + def execAndReturn(op: Variance): Variance = + Nd4j.getExecutioner.execAndReturn(op) /** Execute and return the result from an index accumulation * * @param op the index accumulation operation to execute * @return the accumulated index */ - def execAndReturn(op: IndexAccumulation): IndexAccumulation = ??? + def execAndReturn(op: IndexAccumulation): IndexAccumulation = + Nd4j.getExecutioner.execAndReturn(op) /** Execute and return the result from a scalar op * * @param op the operation to execute * @return the accumulated result */ - def execAndReturn(op: ScalarOp): ScalarOp = ??? + def execAndReturn(op: ScalarOp): ScalarOp = + Nd4j.getExecutioner.execAndReturn(op) /** Execute and return the result from a vector op * * @param op */ - def execAndReturn(op: BroadcastOp): BroadcastOp = ??? + def execAndReturn(op: BroadcastOp): BroadcastOp = + Nd4j.getExecutioner.execAndReturn(op) /** * Execute a reduceOp, possibly along one or more dimensions @@ -149,7 +170,8 @@ class FunctionalOpExecutioner extends OpExecutioner { * @param reduceOp the reduceOp * @return the reduceOp op */ - def exec(reduceOp: ReduceOp): INDArray = ??? + def exec(reduceOp: ReduceOp): INDArray = + Nd4j.getExecutioner.exec(reduceOp) /** * Execute a broadcast op, possibly along one or more dimensions @@ -157,7 +179,8 @@ class FunctionalOpExecutioner extends OpExecutioner { * @param broadcast the accumulation * @return the broadcast op */ - def exec(broadcast: BroadcastOp): INDArray = ??? + def exec(broadcast: BroadcastOp): INDArray = + Nd4j.getExecutioner.exec(broadcast) /** * Execute ScalarOp @@ -165,7 +188,8 @@ class FunctionalOpExecutioner extends OpExecutioner { * @param broadcast * @return */ - def exec(broadcast: ScalarOp): INDArray = ??? + def exec(broadcast: ScalarOp): INDArray = + Nd4j.exec(broadcast) /** * Execute an variance accumulation op, possibly along one or more dimensions @@ -173,14 +197,16 @@ class FunctionalOpExecutioner extends OpExecutioner { * @param accumulation the accumulation * @return the accmulation op */ - def exec(accumulation: Variance): INDArray = ??? + def exec(accumulation: Variance): INDArray = + Nd4j.getExecutioner.exec(accumulation) /** Execute an index accumulation along one or more dimensions * * @param indexAccum the index accumulation operation * @return result */ - def exec(indexAccum: IndexAccumulation): INDArray = ??? + def exec(indexAccum: IndexAccumulation): INDArray = + Nd4j.getExecutioner.exec(indexAccum) /** * @@ -190,34 +216,39 @@ class FunctionalOpExecutioner extends OpExecutioner { * @param op the operation to execute * @return the result from the operation */ - def execAndReturn(op: Op): Op = ??? + def execAndReturn(op: Op): Op = + Nd4j.getExecutioner.execAndReturn(op) /** * Execute MetaOp * * @param op */ - def exec(op: MetaOp): Unit = ??? + def exec(op: MetaOp): Unit = + Nd4j.getExecutioner.exec(op) /** * Execute GridOp * * @param op */ - def exec(op: GridOp): Unit = ??? + def exec(op: GridOp): Unit = + Nd4j.getExecutioner.exec(op) /** * * @param op */ - def exec(op: Aggregate): Unit = ??? + def exec(op: Aggregate): Unit = + Nd4j.getExecutioner.exec(op) /** * This method executes previously built batch * * @param batch */ - def exec[T <: Aggregate](batch: Batch[T]): Unit = ??? + def exec[T <: Aggregate](batch: Batch[T]): Unit = + Nd4j.getExecutioner.exec(batch) /** * This method takes arbitrary sized list of aggregates, @@ -225,14 +256,16 @@ class FunctionalOpExecutioner extends OpExecutioner { * * @param batch */ - def exec(batch: java.util.List[Aggregate]): Unit = ??? + def exec(batch: java.util.List[Aggregate]): Unit = + Nd4j.getExecutioner.exec(batch) /** * This method executes specified RandomOp using default RNG available via Nd4j.getRandom() * * @param op */ - def exec(op: RandomOp): INDArray = ??? + def exec(op: RandomOp): INDArray = + Nd4j.getExecutioner.exec(op) /** * This method executes specific RandomOp against specified RNG @@ -240,7 +273,8 @@ class FunctionalOpExecutioner extends OpExecutioner { * @param op * @param rng */ - def exec(op: RandomOp, rng: Random): INDArray = ??? + def exec(op: RandomOp, rng: Random): INDArray = + Nd4j.getExecutioner.exec(op, rng) /** * This method return set of key/value and @@ -249,7 +283,8 @@ class FunctionalOpExecutioner extends OpExecutioner { * * @return */ - def getEnvironmentInformation: Properties = ??? + def getEnvironmentInformation: Properties = + Nd4j.getExecutioner.getEnvironmentInformation /** * This method specifies desired profiling mode @@ -263,7 +298,8 @@ class FunctionalOpExecutioner extends OpExecutioner { * * @param config */ - def setProfilingConfig(config: ProfilerConfig): Unit = ??? + def setProfilingConfig(config: ProfilerConfig): Unit = + Nd4j.getExecutioner.setProfilingConfig(config) /** * Ths method returns current profiling @@ -277,12 +313,14 @@ class FunctionalOpExecutioner extends OpExecutioner { * * @return */ - def getTADManager: TADManager = ??? + def getTADManager: TADManager = + Nd4j.getExecutioner.getTADManager /** * This method prints out environmental information returned by getEnvironmentInformation() method */ - def printEnvironmentInformation(): Unit = ??? + def printEnvironmentInformation(): Unit = + Nd4j.getExecutioner.printEnvironmentInformation() /** * This method ensures all operations that supposed to be executed at this moment, are executed. @@ -364,20 +402,19 @@ class FunctionalOpExecutioner extends OpExecutioner { * @param context * @return method returns output arrays defined within context */ - def exec(op: CustomOp, context: OpContext): Array[INDArray] = ??? + def exec(op: CustomOp, context: OpContext): Array[INDArray] = + Nd4j.getExecutioner.exec(op, context) - def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] = ??? + def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] = + Nd4j.getExecutioner.calculateOutputShape(op) /** * Equivalent to calli */ - def allocateOutputArrays(op: CustomOp): Array[INDArray] = ??? + def allocateOutputArrays(op: CustomOp): Array[INDArray] = + Nd4j.getExecutioner.allocateOutputArrays(op) - def enableDebugMode(reallyEnable: Boolean): Unit = ??? - - def enableVerboseMode(reallyEnable: Boolean): Unit = ??? - - def isExperimentalMode: Boolean = ??? + def isExperimentalMode: Boolean = true def registerGraph(id: Long, graph: Pointer): Unit = ??? diff --git a/nd4s/src/test/scala/org/nd4s/NDArrayCollectionAPITest.scala b/nd4s/src/test/scala/org/nd4s/NDArrayCollectionAPITest.scala index ec7489e49..e250d990f 100644 --- a/nd4s/src/test/scala/org/nd4s/NDArrayCollectionAPITest.scala +++ b/nd4s/src/test/scala/org/nd4s/NDArrayCollectionAPITest.scala @@ -17,6 +17,7 @@ package org.nd4s import org.nd4j.linalg.api.ndarray.INDArray import org.nd4s.Implicits._ +import org.nd4s.ops.FunctionalOpExecutioner import org.scalatest.{ FlatSpec, Matchers } class NDArrayCollectionAPITest extends FlatSpec with Matchers { @@ -186,4 +187,18 @@ class NDArrayCollectionAPITest extends FlatSpec with Matchers { assert(false == results) } + "FunctionalOpExecutioner" should "allow debug and verbose" in { + val executioner = new FunctionalOpExecutioner + executioner.enableDebugMode(true) + executioner.enableVerboseMode(true) + + assert(executioner.isDebug) + assert(executioner.isVerbose) + } + + "FunctionalOpExecutioner" should "provide access to environment information" in { + FunctionalOpExecutioner.apply.printEnvironmentInformation() + val environment = FunctionalOpExecutioner.apply.getEnvironmentInformation + assert(environment != null) + } }