[WIP] Some nd4s tweaks (#68)

* Executioner fallback

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Tests for executioner

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
Alexander Stoyakin 2019-07-19 18:41:05 +03:00 committed by AlexDBlack
parent 2fb4a52a02
commit f29f19e9e9
2 changed files with 85 additions and 33 deletions

View File

@ -19,6 +19,7 @@ import java.util.{ List, Map, Properties }
import org.bytedeco.javacpp.Pointer import org.bytedeco.javacpp.Pointer
import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType, Utf8Buffer } import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType, Utf8Buffer }
import org.nd4j.linalg.api.environment.Nd4jEnvironment
import org.nd4j.linalg.api.ndarray.{ INDArray, INDArrayStatistics } import org.nd4j.linalg.api.ndarray.{ INDArray, INDArrayStatistics }
import org.nd4j.linalg.api.ops.aggregates.{ Aggregate, Batch } import org.nd4j.linalg.api.ops.aggregates.{ Aggregate, Batch }
import org.nd4j.linalg.api.ops._ import org.nd4j.linalg.api.ops._
@ -30,19 +31,33 @@ import org.nd4j.linalg.api.shape.{ LongShapeDescriptor, TadPack }
import org.nd4j.linalg.cache.TADManager import org.nd4j.linalg.cache.TADManager
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.profiler.ProfilerConfig import org.nd4j.linalg.profiler.ProfilerConfig
import org.slf4j.{ Logger, LoggerFactory }
object FunctionalOpExecutioner { object FunctionalOpExecutioner {
def apply: FunctionalOpExecutioner = new FunctionalOpExecutioner() def apply: FunctionalOpExecutioner = new FunctionalOpExecutioner()
} }
class FunctionalOpExecutioner extends OpExecutioner { class FunctionalOpExecutioner extends OpExecutioner {
def isVerbose: Boolean = ???
def log: Logger = LoggerFactory.getLogger(FunctionalOpExecutioner.getClass)
private[this] var verboseEnabled: Boolean = false
def isVerbose: Boolean = verboseEnabled
def enableVerboseMode(reallyEnable: Boolean): Unit =
verboseEnabled = reallyEnable
/** /**
* This method returns true if debug mode is enabled, false otherwise * This method returns true if debug mode is enabled, false otherwise
* *
* @return * @return
*/ */
def isDebug: Boolean = ??? private[this] var debugEnabled: Boolean = false
def isDebug: Boolean = debugEnabled
def enableDebugMode(reallyEnable: Boolean): Unit =
debugEnabled = reallyEnable
/** /**
* This method returns type for this executioner instance * This method returns type for this executioner instance
@ -112,7 +127,8 @@ class FunctionalOpExecutioner extends OpExecutioner {
* *
* @param op the operation to execute * @param op the operation to execute
*/ */
def execAndReturn(op: TransformOp): TransformOp = ??? def execAndReturn(op: TransformOp): TransformOp =
Nd4j.getExecutioner.execAndReturn(op)
/** /**
* Execute and return the result from an accumulation * Execute and return the result from an accumulation
@ -120,28 +136,33 @@ class FunctionalOpExecutioner extends OpExecutioner {
* @param op the operation to execute * @param op the operation to execute
* @return the accumulated result * @return the accumulated result
*/ */
def execAndReturn(op: ReduceOp): ReduceOp = ??? def execAndReturn(op: ReduceOp): ReduceOp =
Nd4j.getExecutioner.execAndReturn(op)
def execAndReturn(op: Variance): Variance = ??? def execAndReturn(op: Variance): Variance =
Nd4j.getExecutioner.execAndReturn(op)
/** Execute and return the result from an index accumulation /** Execute and return the result from an index accumulation
* *
* @param op the index accumulation operation to execute * @param op the index accumulation operation to execute
* @return the accumulated index * @return the accumulated index
*/ */
def execAndReturn(op: IndexAccumulation): IndexAccumulation = ??? def execAndReturn(op: IndexAccumulation): IndexAccumulation =
Nd4j.getExecutioner.execAndReturn(op)
/** Execute and return the result from a scalar op /** Execute and return the result from a scalar op
* *
* @param op the operation to execute * @param op the operation to execute
* @return the accumulated result * @return the accumulated result
*/ */
def execAndReturn(op: ScalarOp): ScalarOp = ??? def execAndReturn(op: ScalarOp): ScalarOp =
Nd4j.getExecutioner.execAndReturn(op)
/** Execute and return the result from a vector op /** Execute and return the result from a vector op
* *
* @param op */ * @param op */
def execAndReturn(op: BroadcastOp): BroadcastOp = ??? def execAndReturn(op: BroadcastOp): BroadcastOp =
Nd4j.getExecutioner.execAndReturn(op)
/** /**
* Execute a reduceOp, possibly along one or more dimensions * Execute a reduceOp, possibly along one or more dimensions
@ -149,7 +170,8 @@ class FunctionalOpExecutioner extends OpExecutioner {
* @param reduceOp the reduceOp * @param reduceOp the reduceOp
* @return the reduceOp op * @return the reduceOp op
*/ */
def exec(reduceOp: ReduceOp): INDArray = ??? def exec(reduceOp: ReduceOp): INDArray =
Nd4j.getExecutioner.exec(reduceOp)
/** /**
* Execute a broadcast op, possibly along one or more dimensions * Execute a broadcast op, possibly along one or more dimensions
@ -157,7 +179,8 @@ class FunctionalOpExecutioner extends OpExecutioner {
* @param broadcast the accumulation * @param broadcast the accumulation
* @return the broadcast op * @return the broadcast op
*/ */
def exec(broadcast: BroadcastOp): INDArray = ??? def exec(broadcast: BroadcastOp): INDArray =
Nd4j.getExecutioner.exec(broadcast)
/** /**
* Execute ScalarOp * Execute ScalarOp
@ -165,7 +188,8 @@ class FunctionalOpExecutioner extends OpExecutioner {
* @param broadcast * @param broadcast
* @return * @return
*/ */
def exec(broadcast: ScalarOp): INDArray = ??? def exec(broadcast: ScalarOp): INDArray =
Nd4j.exec(broadcast)
/** /**
* Execute an variance accumulation op, possibly along one or more dimensions * Execute an variance accumulation op, possibly along one or more dimensions
@ -173,14 +197,16 @@ class FunctionalOpExecutioner extends OpExecutioner {
* @param accumulation the accumulation * @param accumulation the accumulation
* @return the accmulation op * @return the accmulation op
*/ */
def exec(accumulation: Variance): INDArray = ??? def exec(accumulation: Variance): INDArray =
Nd4j.getExecutioner.exec(accumulation)
/** Execute an index accumulation along one or more dimensions /** Execute an index accumulation along one or more dimensions
* *
* @param indexAccum the index accumulation operation * @param indexAccum the index accumulation operation
* @return result * @return result
*/ */
def exec(indexAccum: IndexAccumulation): INDArray = ??? def exec(indexAccum: IndexAccumulation): INDArray =
Nd4j.getExecutioner.exec(indexAccum)
/** /**
* *
@ -190,34 +216,39 @@ class FunctionalOpExecutioner extends OpExecutioner {
* @param op the operation to execute * @param op the operation to execute
* @return the result from the operation * @return the result from the operation
*/ */
def execAndReturn(op: Op): Op = ??? def execAndReturn(op: Op): Op =
Nd4j.getExecutioner.execAndReturn(op)
/** /**
* Execute MetaOp * Execute MetaOp
* *
* @param op * @param op
*/ */
def exec(op: MetaOp): Unit = ??? def exec(op: MetaOp): Unit =
Nd4j.getExecutioner.exec(op)
/** /**
* Execute GridOp * Execute GridOp
* *
* @param op * @param op
*/ */
def exec(op: GridOp): Unit = ??? def exec(op: GridOp): Unit =
Nd4j.getExecutioner.exec(op)
/** /**
* *
* @param op * @param op
*/ */
def exec(op: Aggregate): Unit = ??? def exec(op: Aggregate): Unit =
Nd4j.getExecutioner.exec(op)
/** /**
* This method executes previously built batch * This method executes previously built batch
* *
* @param batch * @param batch
*/ */
def exec[T <: Aggregate](batch: Batch[T]): Unit = ??? def exec[T <: Aggregate](batch: Batch[T]): Unit =
Nd4j.getExecutioner.exec(batch)
/** /**
* This method takes arbitrary sized list of aggregates, * This method takes arbitrary sized list of aggregates,
@ -225,14 +256,16 @@ class FunctionalOpExecutioner extends OpExecutioner {
* *
* @param batch * @param batch
*/ */
def exec(batch: java.util.List[Aggregate]): Unit = ??? def exec(batch: java.util.List[Aggregate]): Unit =
Nd4j.getExecutioner.exec(batch)
/** /**
* This method executes specified RandomOp using default RNG available via Nd4j.getRandom() * This method executes specified RandomOp using default RNG available via Nd4j.getRandom()
* *
* @param op * @param op
*/ */
def exec(op: RandomOp): INDArray = ??? def exec(op: RandomOp): INDArray =
Nd4j.getExecutioner.exec(op)
/** /**
* This method executes specific RandomOp against specified RNG * This method executes specific RandomOp against specified RNG
@ -240,7 +273,8 @@ class FunctionalOpExecutioner extends OpExecutioner {
* @param op * @param op
* @param rng * @param rng
*/ */
def exec(op: RandomOp, rng: Random): INDArray = ??? def exec(op: RandomOp, rng: Random): INDArray =
Nd4j.getExecutioner.exec(op, rng)
/** /**
* This method return set of key/value and * This method return set of key/value and
@ -249,7 +283,8 @@ class FunctionalOpExecutioner extends OpExecutioner {
* *
* @return * @return
*/ */
def getEnvironmentInformation: Properties = ??? def getEnvironmentInformation: Properties =
Nd4j.getExecutioner.getEnvironmentInformation
/** /**
* This method specifies desired profiling mode * This method specifies desired profiling mode
@ -263,7 +298,8 @@ class FunctionalOpExecutioner extends OpExecutioner {
* *
* @param config * @param config
*/ */
def setProfilingConfig(config: ProfilerConfig): Unit = ??? def setProfilingConfig(config: ProfilerConfig): Unit =
Nd4j.getExecutioner.setProfilingConfig(config)
/** /**
* Ths method returns current profiling * Ths method returns current profiling
@ -277,12 +313,14 @@ class FunctionalOpExecutioner extends OpExecutioner {
* *
* @return * @return
*/ */
def getTADManager: TADManager = ??? def getTADManager: TADManager =
Nd4j.getExecutioner.getTADManager
/** /**
* This method prints out environmental information returned by getEnvironmentInformation() method * This method prints out environmental information returned by getEnvironmentInformation() method
*/ */
def printEnvironmentInformation(): Unit = ??? def printEnvironmentInformation(): Unit =
Nd4j.getExecutioner.printEnvironmentInformation()
/** /**
* This method ensures all operations that supposed to be executed at this moment, are executed. * This method ensures all operations that supposed to be executed at this moment, are executed.
@ -364,20 +402,19 @@ class FunctionalOpExecutioner extends OpExecutioner {
* @param context * @param context
* @return method returns output arrays defined within context * @return method returns output arrays defined within context
*/ */
def exec(op: CustomOp, context: OpContext): Array[INDArray] = ??? def exec(op: CustomOp, context: OpContext): Array[INDArray] =
Nd4j.getExecutioner.exec(op, context)
def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] = ??? def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] =
Nd4j.getExecutioner.calculateOutputShape(op)
/** /**
* Equivalent to calli * Equivalent to calli
*/ */
def allocateOutputArrays(op: CustomOp): Array[INDArray] = ??? def allocateOutputArrays(op: CustomOp): Array[INDArray] =
Nd4j.getExecutioner.allocateOutputArrays(op)
def enableDebugMode(reallyEnable: Boolean): Unit = ??? def isExperimentalMode: Boolean = true
def enableVerboseMode(reallyEnable: Boolean): Unit = ???
def isExperimentalMode: Boolean = ???
def registerGraph(id: Long, graph: Pointer): Unit = ??? def registerGraph(id: Long, graph: Pointer): Unit = ???

View File

@ -17,6 +17,7 @@ package org.nd4s
import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4s.Implicits._ import org.nd4s.Implicits._
import org.nd4s.ops.FunctionalOpExecutioner
import org.scalatest.{ FlatSpec, Matchers } import org.scalatest.{ FlatSpec, Matchers }
class NDArrayCollectionAPITest extends FlatSpec with Matchers { class NDArrayCollectionAPITest extends FlatSpec with Matchers {
@ -186,4 +187,18 @@ class NDArrayCollectionAPITest extends FlatSpec with Matchers {
assert(false == results) assert(false == results)
} }
"FunctionalOpExecutioner" should "allow debug and verbose" in {
val executioner = new FunctionalOpExecutioner
executioner.enableDebugMode(true)
executioner.enableVerboseMode(true)
assert(executioner.isDebug)
assert(executioner.isVerbose)
}
"FunctionalOpExecutioner" should "provide access to environment information" in {
FunctionalOpExecutioner.apply.printEnvironmentInformation()
val environment = FunctionalOpExecutioner.apply.getEnvironmentInformation
assert(environment != null)
}
} }