[WIP] Some nd4s tweaks (#68)
* Executioner fallback Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Tests for executioner Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
This commit is contained in:
		
							parent
							
								
									2fb4a52a02
								
							
						
					
					
						commit
						f29f19e9e9
					
				| @ -19,6 +19,7 @@ import java.util.{ List, Map, Properties } | |||||||
| 
 | 
 | ||||||
| import org.bytedeco.javacpp.Pointer | import org.bytedeco.javacpp.Pointer | ||||||
| import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType, Utf8Buffer } | import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType, Utf8Buffer } | ||||||
|  | import org.nd4j.linalg.api.environment.Nd4jEnvironment | ||||||
| import org.nd4j.linalg.api.ndarray.{ INDArray, INDArrayStatistics } | import org.nd4j.linalg.api.ndarray.{ INDArray, INDArrayStatistics } | ||||||
| import org.nd4j.linalg.api.ops.aggregates.{ Aggregate, Batch } | import org.nd4j.linalg.api.ops.aggregates.{ Aggregate, Batch } | ||||||
| import org.nd4j.linalg.api.ops._ | import org.nd4j.linalg.api.ops._ | ||||||
| @ -30,19 +31,33 @@ import org.nd4j.linalg.api.shape.{ LongShapeDescriptor, TadPack } | |||||||
| import org.nd4j.linalg.cache.TADManager | import org.nd4j.linalg.cache.TADManager | ||||||
| import org.nd4j.linalg.factory.Nd4j | import org.nd4j.linalg.factory.Nd4j | ||||||
| import org.nd4j.linalg.profiler.ProfilerConfig | import org.nd4j.linalg.profiler.ProfilerConfig | ||||||
|  | import org.slf4j.{ Logger, LoggerFactory } | ||||||
| 
 | 
 | ||||||
| object FunctionalOpExecutioner { | object FunctionalOpExecutioner { | ||||||
|   def apply: FunctionalOpExecutioner = new FunctionalOpExecutioner() |   def apply: FunctionalOpExecutioner = new FunctionalOpExecutioner() | ||||||
| } | } | ||||||
| class FunctionalOpExecutioner extends OpExecutioner { | class FunctionalOpExecutioner extends OpExecutioner { | ||||||
|   def isVerbose: Boolean = ??? | 
 | ||||||
|  |   def log: Logger = LoggerFactory.getLogger(FunctionalOpExecutioner.getClass) | ||||||
|  | 
 | ||||||
|  |   private[this] var verboseEnabled: Boolean = false | ||||||
|  | 
 | ||||||
|  |   def isVerbose: Boolean = verboseEnabled | ||||||
|  | 
 | ||||||
|  |   def enableVerboseMode(reallyEnable: Boolean): Unit = | ||||||
|  |     verboseEnabled = reallyEnable | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * This method returns true if debug mode is enabled, false otherwise |     * This method returns true if debug mode is enabled, false otherwise | ||||||
|     * |     * | ||||||
|     * @return |     * @return | ||||||
|     */ |     */ | ||||||
|   def isDebug: Boolean = ??? |   private[this] var debugEnabled: Boolean = false | ||||||
|  | 
 | ||||||
|  |   def isDebug: Boolean = debugEnabled | ||||||
|  | 
 | ||||||
|  |   def enableDebugMode(reallyEnable: Boolean): Unit = | ||||||
|  |     debugEnabled = reallyEnable | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * This method returns type for this executioner instance |     * This method returns type for this executioner instance | ||||||
| @ -112,7 +127,8 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * |     * | ||||||
|     * @param op the operation to execute |     * @param op the operation to execute | ||||||
|     */ |     */ | ||||||
|   def execAndReturn(op: TransformOp): TransformOp = ??? |   def execAndReturn(op: TransformOp): TransformOp = | ||||||
|  |     Nd4j.getExecutioner.execAndReturn(op) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * Execute and return the result from an accumulation |     * Execute and return the result from an accumulation | ||||||
| @ -120,28 +136,33 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * @param op the operation to execute |     * @param op the operation to execute | ||||||
|     * @return the accumulated result |     * @return the accumulated result | ||||||
|     */ |     */ | ||||||
|   def execAndReturn(op: ReduceOp): ReduceOp = ??? |   def execAndReturn(op: ReduceOp): ReduceOp = | ||||||
|  |     Nd4j.getExecutioner.execAndReturn(op) | ||||||
| 
 | 
 | ||||||
|   def execAndReturn(op: Variance): Variance = ??? |   def execAndReturn(op: Variance): Variance = | ||||||
|  |     Nd4j.getExecutioner.execAndReturn(op) | ||||||
| 
 | 
 | ||||||
|   /** Execute and return the result from an index accumulation |   /** Execute and return the result from an index accumulation | ||||||
|     * |     * | ||||||
|     * @param op the index accumulation operation to execute |     * @param op the index accumulation operation to execute | ||||||
|     * @return the accumulated index |     * @return the accumulated index | ||||||
|     */ |     */ | ||||||
|   def execAndReturn(op: IndexAccumulation): IndexAccumulation = ??? |   def execAndReturn(op: IndexAccumulation): IndexAccumulation = | ||||||
|  |     Nd4j.getExecutioner.execAndReturn(op) | ||||||
| 
 | 
 | ||||||
|   /** Execute and return the result from a scalar op |   /** Execute and return the result from a scalar op | ||||||
|     * |     * | ||||||
|     * @param op the operation to execute |     * @param op the operation to execute | ||||||
|     * @return the accumulated result |     * @return the accumulated result | ||||||
|     */ |     */ | ||||||
|   def execAndReturn(op: ScalarOp): ScalarOp = ??? |   def execAndReturn(op: ScalarOp): ScalarOp = | ||||||
|  |     Nd4j.getExecutioner.execAndReturn(op) | ||||||
| 
 | 
 | ||||||
|   /** Execute and return the result from a vector op |   /** Execute and return the result from a vector op | ||||||
|     * |     * | ||||||
|     * @param op */ |     * @param op */ | ||||||
|   def execAndReturn(op: BroadcastOp): BroadcastOp = ??? |   def execAndReturn(op: BroadcastOp): BroadcastOp = | ||||||
|  |     Nd4j.getExecutioner.execAndReturn(op) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * Execute a reduceOp, possibly along one or more dimensions |     * Execute a reduceOp, possibly along one or more dimensions | ||||||
| @ -149,7 +170,8 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * @param reduceOp the reduceOp |     * @param reduceOp the reduceOp | ||||||
|     * @return the reduceOp op |     * @return the reduceOp op | ||||||
|     */ |     */ | ||||||
|   def exec(reduceOp: ReduceOp): INDArray = ??? |   def exec(reduceOp: ReduceOp): INDArray = | ||||||
|  |     Nd4j.getExecutioner.exec(reduceOp) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * Execute a broadcast op, possibly along one or more dimensions |     * Execute a broadcast op, possibly along one or more dimensions | ||||||
| @ -157,7 +179,8 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * @param broadcast the accumulation |     * @param broadcast the accumulation | ||||||
|     * @return the broadcast op |     * @return the broadcast op | ||||||
|     */ |     */ | ||||||
|   def exec(broadcast: BroadcastOp): INDArray = ??? |   def exec(broadcast: BroadcastOp): INDArray = | ||||||
|  |     Nd4j.getExecutioner.exec(broadcast) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * Execute ScalarOp |     * Execute ScalarOp | ||||||
| @ -165,7 +188,8 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * @param broadcast |     * @param broadcast | ||||||
|     * @return |     * @return | ||||||
|     */ |     */ | ||||||
|   def exec(broadcast: ScalarOp): INDArray = ??? |   def exec(broadcast: ScalarOp): INDArray = | ||||||
|  |     Nd4j.exec(broadcast) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * Execute an variance accumulation op, possibly along one or more dimensions |     * Execute an variance accumulation op, possibly along one or more dimensions | ||||||
| @ -173,14 +197,16 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * @param accumulation the accumulation |     * @param accumulation the accumulation | ||||||
|     * @return the accmulation op |     * @return the accmulation op | ||||||
|     */ |     */ | ||||||
|   def exec(accumulation: Variance): INDArray = ??? |   def exec(accumulation: Variance): INDArray = | ||||||
|  |     Nd4j.getExecutioner.exec(accumulation) | ||||||
| 
 | 
 | ||||||
|   /** Execute an index accumulation along one or more dimensions |   /** Execute an index accumulation along one or more dimensions | ||||||
|     * |     * | ||||||
|     * @param indexAccum the index accumulation operation |     * @param indexAccum the index accumulation operation | ||||||
|     * @return result |     * @return result | ||||||
|     */ |     */ | ||||||
|   def exec(indexAccum: IndexAccumulation): INDArray = ??? |   def exec(indexAccum: IndexAccumulation): INDArray = | ||||||
|  |     Nd4j.getExecutioner.exec(indexAccum) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * |     * | ||||||
| @ -190,34 +216,39 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * @param op the operation to execute |     * @param op the operation to execute | ||||||
|     * @return the result from the operation |     * @return the result from the operation | ||||||
|     */ |     */ | ||||||
|   def execAndReturn(op: Op): Op = ??? |   def execAndReturn(op: Op): Op = | ||||||
|  |     Nd4j.getExecutioner.execAndReturn(op) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * Execute MetaOp |     * Execute MetaOp | ||||||
|     * |     * | ||||||
|     * @param op |     * @param op | ||||||
|     */ |     */ | ||||||
|   def exec(op: MetaOp): Unit = ??? |   def exec(op: MetaOp): Unit = | ||||||
|  |     Nd4j.getExecutioner.exec(op) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * Execute GridOp |     * Execute GridOp | ||||||
|     * |     * | ||||||
|     * @param op |     * @param op | ||||||
|     */ |     */ | ||||||
|   def exec(op: GridOp): Unit = ??? |   def exec(op: GridOp): Unit = | ||||||
|  |     Nd4j.getExecutioner.exec(op) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * |     * | ||||||
|     * @param op |     * @param op | ||||||
|     */ |     */ | ||||||
|   def exec(op: Aggregate): Unit = ??? |   def exec(op: Aggregate): Unit = | ||||||
|  |     Nd4j.getExecutioner.exec(op) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * This method executes previously built batch |     * This method executes previously built batch | ||||||
|     * |     * | ||||||
|     * @param batch |     * @param batch | ||||||
|     */ |     */ | ||||||
|   def exec[T <: Aggregate](batch: Batch[T]): Unit = ??? |   def exec[T <: Aggregate](batch: Batch[T]): Unit = | ||||||
|  |     Nd4j.getExecutioner.exec(batch) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * This method takes arbitrary sized list of aggregates, |     * This method takes arbitrary sized list of aggregates, | ||||||
| @ -225,14 +256,16 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * |     * | ||||||
|     * @param batch |     * @param batch | ||||||
|     */ |     */ | ||||||
|   def exec(batch: java.util.List[Aggregate]): Unit = ??? |   def exec(batch: java.util.List[Aggregate]): Unit = | ||||||
|  |     Nd4j.getExecutioner.exec(batch) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * This method executes specified RandomOp using default RNG available via Nd4j.getRandom() |     * This method executes specified RandomOp using default RNG available via Nd4j.getRandom() | ||||||
|     * |     * | ||||||
|     * @param op |     * @param op | ||||||
|     */ |     */ | ||||||
|   def exec(op: RandomOp): INDArray = ??? |   def exec(op: RandomOp): INDArray = | ||||||
|  |     Nd4j.getExecutioner.exec(op) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * This method executes specific RandomOp against specified RNG |     * This method executes specific RandomOp against specified RNG | ||||||
| @ -240,7 +273,8 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * @param op |     * @param op | ||||||
|     * @param rng |     * @param rng | ||||||
|     */ |     */ | ||||||
|   def exec(op: RandomOp, rng: Random): INDArray = ??? |   def exec(op: RandomOp, rng: Random): INDArray = | ||||||
|  |     Nd4j.getExecutioner.exec(op, rng) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * This method return set of key/value and |     * This method return set of key/value and | ||||||
| @ -249,7 +283,8 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * |     * | ||||||
|     * @return |     * @return | ||||||
|     */ |     */ | ||||||
|   def getEnvironmentInformation: Properties = ??? |   def getEnvironmentInformation: Properties = | ||||||
|  |     Nd4j.getExecutioner.getEnvironmentInformation | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * This method specifies desired profiling mode |     * This method specifies desired profiling mode | ||||||
| @ -263,7 +298,8 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * |     * | ||||||
|     * @param config |     * @param config | ||||||
|     */ |     */ | ||||||
|   def setProfilingConfig(config: ProfilerConfig): Unit = ??? |   def setProfilingConfig(config: ProfilerConfig): Unit = | ||||||
|  |     Nd4j.getExecutioner.setProfilingConfig(config) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * Ths method returns current profiling |     * Ths method returns current profiling | ||||||
| @ -277,12 +313,14 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * |     * | ||||||
|     * @return |     * @return | ||||||
|     */ |     */ | ||||||
|   def getTADManager: TADManager = ??? |   def getTADManager: TADManager = | ||||||
|  |     Nd4j.getExecutioner.getTADManager | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * This method prints out environmental information returned by getEnvironmentInformation() method |     * This method prints out environmental information returned by getEnvironmentInformation() method | ||||||
|     */ |     */ | ||||||
|   def printEnvironmentInformation(): Unit = ??? |   def printEnvironmentInformation(): Unit = | ||||||
|  |     Nd4j.getExecutioner.printEnvironmentInformation() | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * This method ensures all operations that supposed to be executed at this moment, are executed. |     * This method ensures all operations that supposed to be executed at this moment, are executed. | ||||||
| @ -364,20 +402,19 @@ class FunctionalOpExecutioner extends OpExecutioner { | |||||||
|     * @param context |     * @param context | ||||||
|     * @return method returns output arrays defined within context |     * @return method returns output arrays defined within context | ||||||
|     */ |     */ | ||||||
|   def exec(op: CustomOp, context: OpContext): Array[INDArray] = ??? |   def exec(op: CustomOp, context: OpContext): Array[INDArray] = | ||||||
|  |     Nd4j.getExecutioner.exec(op, context) | ||||||
| 
 | 
 | ||||||
|   def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] = ??? |   def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] = | ||||||
|  |     Nd4j.getExecutioner.calculateOutputShape(op) | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|     * Equivalent to calli |     * Equivalent to calli | ||||||
|     */ |     */ | ||||||
|   def allocateOutputArrays(op: CustomOp): Array[INDArray] = ??? |   def allocateOutputArrays(op: CustomOp): Array[INDArray] = | ||||||
|  |     Nd4j.getExecutioner.allocateOutputArrays(op) | ||||||
| 
 | 
 | ||||||
|   def enableDebugMode(reallyEnable: Boolean): Unit = ??? |   def isExperimentalMode: Boolean = true | ||||||
| 
 |  | ||||||
|   def enableVerboseMode(reallyEnable: Boolean): Unit = ??? |  | ||||||
| 
 |  | ||||||
|   def isExperimentalMode: Boolean = ??? |  | ||||||
| 
 | 
 | ||||||
|   def registerGraph(id: Long, graph: Pointer): Unit = ??? |   def registerGraph(id: Long, graph: Pointer): Unit = ??? | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -17,6 +17,7 @@ package org.nd4s | |||||||
| 
 | 
 | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray | import org.nd4j.linalg.api.ndarray.INDArray | ||||||
| import org.nd4s.Implicits._ | import org.nd4s.Implicits._ | ||||||
|  | import org.nd4s.ops.FunctionalOpExecutioner | ||||||
| import org.scalatest.{ FlatSpec, Matchers } | import org.scalatest.{ FlatSpec, Matchers } | ||||||
| 
 | 
 | ||||||
| class NDArrayCollectionAPITest extends FlatSpec with Matchers { | class NDArrayCollectionAPITest extends FlatSpec with Matchers { | ||||||
| @ -186,4 +187,18 @@ class NDArrayCollectionAPITest extends FlatSpec with Matchers { | |||||||
|     assert(false == results) |     assert(false == results) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   "FunctionalOpExecutioner" should "allow debug and verbose" in { | ||||||
|  |     val executioner = new FunctionalOpExecutioner | ||||||
|  |     executioner.enableDebugMode(true) | ||||||
|  |     executioner.enableVerboseMode(true) | ||||||
|  | 
 | ||||||
|  |     assert(executioner.isDebug) | ||||||
|  |     assert(executioner.isVerbose) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   "FunctionalOpExecutioner" should "provide access to environment information" in { | ||||||
|  |     FunctionalOpExecutioner.apply.printEnvironmentInformation() | ||||||
|  |     val environment = FunctionalOpExecutioner.apply.getEnvironmentInformation | ||||||
|  |     assert(environment != null) | ||||||
|  |   } | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user