Indexing demo
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
This commit is contained in:
		
							parent
							
								
									b8f2a83a5a
								
							
						
					
					
						commit
						7a1e327516
					
				| @ -44,6 +44,9 @@ class SameDiffWrapper { | ||||
|   def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable = | ||||
|     sd.`var`(name, dataType, shape: _*) | ||||
| 
 | ||||
|   def bind(data: INDArray): SDVariable = | ||||
|     sd.`var`("", data) | ||||
| 
 | ||||
|   def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable = | ||||
|     sd.`var`(name, dataType, shape: _*) | ||||
| 
 | ||||
| @ -63,6 +66,16 @@ class SDVariableWrapper { | ||||
| 
 | ||||
|   def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index)) | ||||
| 
 | ||||
|   def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*) | ||||
| 
 | ||||
|   def apply(x: SDIndex)(y: SDIndex): SDVariable = | ||||
|     (x, y) match { | ||||
|       case (_, y) => thisVariable.get(SDIndex.all(), y) | ||||
|       case (x, _) => thisVariable.get(x, SDIndex.all()) | ||||
|       case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all()) | ||||
|       case (x, y) => thisVariable.get(x, y) | ||||
|     } | ||||
| 
 | ||||
|   def add(other: Double): Unit = thisVariable.add(other) | ||||
| 
 | ||||
|   def *(other: SDVariable): SDVariable = | ||||
|  | ||||
| @ -205,9 +205,18 @@ class MathTest extends FlatSpec with Matchers { | ||||
|     implicit val sd = SameDiff.create | ||||
| 
 | ||||
|     val arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L) | ||||
|     val x = sd.`var`(arr) | ||||
|     val x = sd.bind(arr) | ||||
|     val y = new SDVariableWrapper(x) | ||||
| 
 | ||||
|     x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr | ||||
|   } | ||||
| 
 | ||||
|   "SDVariable " should "be indexable in 2d" in { | ||||
|     implicit val sd = SameDiff.create | ||||
| 
 | ||||
|     val arr = Nd4j.rand(2, 5) | ||||
|     val x = sd.bind(arr) | ||||
| 
 | ||||
|     println(x(SDIndex.point(0), _: SDIndex).getArr) | ||||
|   } | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user