[WIP] nd4s - Scala operators for SameDiff (#113)
* Jar packaging for maven. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Typo fixed. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Minimal viable prototype for SD. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Tests corrected. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Few fixes for bfloat16 in java and cpp (#114). Signed-off-by: raver119 <raver119@gmail.com>
* Nd4j refactoring (#112): refactoring; make test public; fixes read refactoring. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* Enabled test. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Test copied from nd4j. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* [WIP] bitwise ops (#115): cyclic_shift_bits + test; shift_bits + test; OMP_IF replacement. Signed-off-by: raver119 <raver119@gmail.com>
* Thin wrapper added. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Cleanup. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Shugeo cuda tests (#116): added tests for get_seed/set_seed, scatter_sub/mul/div, hardsigmoid/hardsigmoid_bp, hardtanh/hardtanh_bp, histogram, identity, log1p, mergemaxindex, mod/mod_bp, rank, rationaltanh/rationaltanh_bp, realdiv/realdiv_bp, rectifiedtanh/_bp, shapes_of, size, softplus/_bp, softsign/_bp, toggle_bits, truncatediv and unstack_list ops; fixed tests for FloorDiv; fixed processing of OP_IMPL and similar definitions; refactored mergemaxindex op and the to_int32/uint32/float16/float32/double/int64/uint64 ops and tests; refactored the mergemaxindex op helper for the cuda platform; fixed the cuda kernel for the histogram op helper; refactored skipgram to avoid early buffer shift; fixed check in the non_max_suppression cuda helper; added a cuda implementation of the skipgram op helpers (working revision); fixed the mergeMaxIndex kernel and moved it to a separate source file.
* Adding arithmetic. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Eliminated memory leaks and dropped waste prints with tests (#117).
* Added tests. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Fix test; no openmp for ClipByGlobalNorm. Signed-off-by: raver119 <raver119@gmail.com>
* Stubs for ops. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* [WIP] right shift ops (#118): right shift ops; typo; rotr test. Signed-off-by: raver119 <raver119@gmail.com>
* fix: IOException no longer thrown by read() (#120). Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* Small fix in TensorflowConversion class (#121). Signed-off-by: Alex Black <blacka101@gmail.com>
* Shyrma concat2 (#119): rewrite/improve concat; got rid of unnecessary argument in concat kernel. Signed-off-by: Yurii <yurii@skymind.io>
* InferenceSession additional validation for shape calc (#122). Signed-off-by: Alex Black <blacka101@gmail.com>
* [WIP] build fix (#124): AffinityManager changes; build fixes. Signed-off-by: raver119 <raver119@gmail.com>
* OP/CONFIGURABLE_OP shapefn fix (#125). Signed-off-by: raver119 <raver119@gmail.com>
* Some ops added. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Nd4j refactoring (last one!) (#123): fix: IOException no longer thrown by read(); refactoring; last refactorings. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* Advanced tests. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* [WIP] Java wrappers (#126): shift/rshift/rotl/rotr java/sd wrappers; few additional wrappers; minor naming tweak. Signed-off-by: raver119 <raver119@gmail.com>
* Test added. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* One more build fix. Signed-off-by: raver119 <raver119@gmail.com>
* Ops added. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Small build fixes (#127): small build fixes; fix RL4J; test fixes; another fix. Signed-off-by: Alex Black <blacka101@gmail.com>
* Parent module name fix. Signed-off-by: raver119 <raver119@gmail.com>
* [WIP] Roll rewritten (#128): process correct input vector; added tests for roll; refactored roll to conform with TF; eliminated memory leaks in the Roll op tests.
* No thread_local for cpu. Signed-off-by: raver119 <raver119@gmail.com>
* Tests added. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Boolean logic ops. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Test added. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Shift operations. Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
parent 30b51f8085
commit 10d676e0b8
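The change set below registers the new bit-shift ops on the Java side and adds Scala operator syntax on SDVariable (arithmetic, boolean logic, and shifts) through implicit wrappers. As a quick orientation, here is a minimal sketch of the kind of client code this enables; the object name and values are illustrative only, and the behaviour shown follows the MathTest assertions added in this commit.

```scala
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.samediff.implicits.Implicits._

object Nd4sSameDiffSketch extends App {
  // The shift and boolean overloads need an implicit SameDiff in scope.
  implicit val sd: SameDiff = SameDiff.create()

  val x = sd.constant(16)
  println((x << 2).eval.toIntVector.head) // 64, via the bit-shift ops this commit registers
  println((x >> 2).eval.toIntVector.head) // 4

  val t = sd.constant(Nd4j.scalar(true))
  println((t & false).eval.toIntVector.head) // 0
  println((t | false).eval.toIntVector.head) // 1
}
```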
@@ -578,7 +578,12 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class,
             org.nd4j.linalg.api.ops.random.impl.Range.class,
             org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
-            org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class
+            org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class
     );

     static {
nd4s/pom.xml
@@ -30,7 +30,7 @@

     <groupId>org.nd4j</groupId>
     <artifactId>nd4s</artifactId>
-    <packaging>pom</packaging>
+    <packaging>jar</packaging>

     <name>nd4s</name>

@@ -280,6 +280,19 @@
             </gitDescribe>
           </configuration>
         </plugin>
+        <plugin>
+          <groupId>org.apache.maven.plugins</groupId>
+          <artifactId>maven-jar-plugin</artifactId>
+          <executions>
+            <execution>
+              <id>make-a-jar</id>
+              <phase>compile</phase>
+              <goals>
+                <goal>jar</goal>
+              </goals>
+            </execution>
+          </executions>
+        </plugin>
       </plugins>
     </build>
 </project>
@@ -0,0 +1,157 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.nd4s.samediff

import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.autodiff.samediff.SDVariable
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.factory.Nd4j

/**
  * Provides wrappers for nd4j SameDiff and related classes.
  *
  * The wrappers are designed to be used implicitly; client code
  * should look like regular nd4j code, with additional syntactic
  * sugar and Scala-specific conveniences.
  *
  * @author Alexander Stoyakin
  */
class SameDiffWrapper {

  var sd: SameDiff = SameDiff.create()

  def this(sd: SameDiff) {
    this()
    this.sd = sd
  }

  def bind(name: String, data: INDArray): SDVariable =
    sd.`var`(name, data)

  def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable =
    sd.`var`(name, dataType, shape: _*)

  def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable =
    sd.`var`(name, dataType, shape: _*)

  def placeHolder(name: String, dataType: DataType, shape: Long*): SDVariable =
    sd.placeHolder(name, dataType, shape: _*)
}

class SDVariableWrapper {

  var thisVariable: SDVariable = null
  var isScalar: Boolean = false

  def this(variable: SDVariable) {
    this()
    thisVariable = variable
  }

  def *(other: SDVariable): SDVariable =
    thisVariable.mul(other)

  def +(other: SDVariable): SDVariable =
    thisVariable.add(other)

  def /(other: SDVariable): SDVariable =
    if (isScalar)
      thisVariable.rdiv(other)
    else
      thisVariable.div(other)

  def -(other: SDVariable): SDVariable =
    if (isScalar)
      thisVariable.rsub(other)
    else
      thisVariable.sub(other)

  def %(other: SDVariable): SDVariable = thisVariable.mod(null, other)

  def `//`(other: SDVariable): SDVariable = thisVariable.fdiv(null, other)

  def unary_-(): SDVariable = thisVariable.neg

  def ^(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.xor(thisVariable, other)
  def |(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.or(thisVariable, other)
  def &(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.and(thisVariable, other)

  def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, x)
  def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShiftRight(null, thisVariable, x)

  // Overloads for numeric arguments
  // Float
  def *(other: Float)(implicit sameDiff: SameDiff): SDVariable =
    thisVariable.mul(sameDiff.constant(other))

  def +(other: Float)(implicit sameDiff: SameDiff): SDVariable =
    thisVariable.add(sameDiff.constant(other))

  def -(other: Float)(implicit sameDiff: SameDiff): SDVariable =
    if (isScalar)
      thisVariable.rsub(sameDiff.constant(other))
    else
      thisVariable.sub(sameDiff.constant(other))

  def /(other: Float)(implicit sameDiff: SameDiff): SDVariable =
    if (isScalar)
      thisVariable.rdiv(sameDiff.constant(other))
    else
      thisVariable.div(sameDiff.constant(other))

  def %(other: Float)(implicit sameDiff: SameDiff): SDVariable =
    thisVariable.mod(null, sameDiff.constant(other))

  def `//`(other: Float)(implicit sameDiff: SameDiff): SDVariable =
    thisVariable.fdiv(null, sameDiff.constant(other))

  // Double
  def *(other: Double)(implicit sameDiff: SameDiff): SDVariable =
    thisVariable.mul(sameDiff.constant(other))

  def +(other: Double)(implicit sameDiff: SameDiff): SDVariable =
    thisVariable.add(sameDiff.constant(other))

  def -(other: Double)(implicit sameDiff: SameDiff): SDVariable =
    if (isScalar)
      thisVariable.rsub(sameDiff.constant(other))
    else
      thisVariable.sub(sameDiff.constant(other))

  def /(other: Double)(implicit sameDiff: SameDiff): SDVariable =
    if (isScalar)
      thisVariable.rdiv(sameDiff.constant(other))
    else
      thisVariable.div(sameDiff.constant(other))

  def %(other: Double)(implicit sameDiff: SameDiff): SDVariable =
    thisVariable.mod(null, sameDiff.constant(other))

  def `//`(other: Double)(implicit sameDiff: SameDiff): SDVariable =
    thisVariable.fdiv(null, sameDiff.constant(other))

  // Int
  def **(x: Int): SDVariable =
    thisVariable.pow(x)

  def ^(other: Boolean)(implicit sameDiff: SameDiff): SDVariable =
    sameDiff.math.xor(thisVariable, sameDiff.constant(Nd4j.scalar(other)))
  def |(other: Boolean)(implicit sameDiff: SameDiff): SDVariable =
    sameDiff.math.or(thisVariable, sameDiff.constant(Nd4j.scalar(other)))
  def &(other: Boolean)(implicit sameDiff: SameDiff): SDVariable =
    sameDiff.math.and(thisVariable, sameDiff.constant(Nd4j.scalar(other)))
}
@@ -0,0 +1,46 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.nd4s.samediff.implicits

import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.samediff.{ SDVariableWrapper, SameDiffWrapper }

object Implicits {
  implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper =
    new SameDiffWrapper(sd)

  implicit def SDVariableToWrapper(variable: SDVariable): SDVariableWrapper =
    new SDVariableWrapper(variable)

  implicit def FloatToSDVariable(x: Float)(implicit sd: SameDiff): SDVariableWrapper = {
    val result = new SDVariableWrapper(sd.constant(x))
    result.isScalar = true
    result
  }

  implicit def DoubleToSDVariable(x: Double)(implicit sd: SameDiff): SDVariableWrapper = {
    val result = new SDVariableWrapper(sd.constant(x))
    result.isScalar = true
    result
  }

  implicit def BooleanToSDVariable(x: Boolean)(implicit sd: SameDiff): SDVariableWrapper = {
    val result = new SDVariableWrapper(sd.constant(Nd4j.scalar(x)))
    result.isScalar = true
    result
  }
}
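The SameDiffWrapper scaladoc above says the wrappers are meant to be picked up implicitly. A minimal sketch of that usage, mirroring the tests added later in this commit; the object and variable names here are illustrative, not part of the patch.

```scala
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4s.Implicits._                    // provides 4.0f.toScalar, as in the tests
import org.nd4s.samediff.implicits.Implicits._ // the conversions defined above

object ImplicitWrapperSketch extends App {
  // The numeric-literal overloads need an implicit SameDiff in scope.
  implicit val sd: SameDiff = SameDiff.create()

  // SameDiffToWrapper supplies bind(...); SDVariableToWrapper supplies the operators.
  val a = sd.bind("a", 4.0f.toScalar)

  val doubled = a * 2.0f   // 8.0f
  val swapped = 2.0f * a   // also 8.0f: FloatToSDVariable wraps the literal as a scalar constant
  val sum     = doubled + a

  println(sum.eval.toFloatVector.head) // 12.0
}
```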
@@ -48,8 +48,8 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
     assert(extracted == expected)
   }

-  it should "be able to extract a part of 2d matrix with double data and offset" in {
-    val ndArray = (1 to 9).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C, offset = 4)
+  it should "be able to extract a part of 2d matrix with double data" in {
+    val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)

     val expectedArray = Array(
       Array(5d, 6d),
@@ -303,7 +303,7 @@ class NDArrayProjectionAPITest extends FlatSpec {
   }

   "SliceProjectedNDArray" should "filter slice correctly" in {
-    val ndArray = (1d until 10d by 1).asNDArray(2, 2, 2)
+    val ndArray = (1d until 9d by 1).asNDArray(2, 2, 2)
     val result = ndArray.sliceP withFilter (input => false)
     assert(result.filtered.isEmpty)
   }
@@ -0,0 +1,117 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.nd4s.samediff

import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.Implicits._
import org.nd4s.samediff.implicits.Implicits._
import org.scalatest.{ FlatSpec, Matchers }

class ConstructionTest extends FlatSpec with Matchers {

  "SameDiff" should "allow composition of arithmetic operations" in {

    val sd = SameDiff.create()
    val ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
    val w1 = sd.bind("w1", Nd4j.rand(DataType.FLOAT, 4, 5))
    val b1 = sd.bind("b1", Nd4j.rand(DataType.FLOAT, 5))

    val mmul1 = ph1 * w1
    val badd1 = mmul1 + b1

    val loss1 = badd1.std("loss1", true)

    sd.setLossVariables("loss1")
    sd.createGradFunction
    for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
      assert(v.getVarName != null && v.gradient != null)
    }
  }

  "SameDiff" should "provide arithmetic operations for float arguments in arbitrary order" in {

    implicit val sd = SameDiff.create()
    val w1 = sd.bind("w1", 4.0f.toScalar)
    var evaluated = w1.eval.castTo(DataType.FLOAT)
    evaluated.toFloatVector.head shouldBe 4.0f

    val w2 = w1 * 2.0f
    w2.eval.toFloatVector.head shouldBe 8.0f
    val w3 = w2 + 2.0f
    w3.eval.toFloatVector.head shouldBe 10.0f

    val w4 = 2.0f * w1
    w4.eval.toFloatVector.head shouldBe 8.0f
    val w5 = 2.0f + w2
    w5.eval.toFloatVector.head shouldBe 10.0f

    val w6 = w1 / 2.0f
    w6.eval.toFloatVector.head shouldBe 2.0f
    val w7 = w2 - 2.0f
    w7.eval.toFloatVector.head shouldBe 6.0f

    val w8 = 2.0f / w1
    w8.eval.toFloatVector.head shouldBe 2.0f

    val w9 = 2.0f - w2
    w9.eval.toFloatVector.head shouldBe 6.0f
  }

  "SameDiff" should "provide arithmetic operations for double arguments in arbitrary order" in {
    implicit val sd = SameDiff.create()
    val w1 = sd.bind("w1", 4.0.toScalar)
    var evaluated = w1.eval.castTo(DataType.DOUBLE)
    evaluated.toFloatVector.head shouldBe 4.0

    val w2 = w1 * 2.0
    w2.eval.toFloatVector.head shouldBe 8.0
    val w3 = w2 + 2.0
    w3.eval.toFloatVector.head shouldBe 10.0

    val w4 = 2.0 * w1
    w4.eval.toFloatVector.head shouldBe 8.0
    val w5 = 2.0 + w2
    w5.eval.toFloatVector.head shouldBe 10.0

    val w6 = w1 / 2.0
    w6.eval.toFloatVector.head shouldBe 2.0
    val w7 = w2 - 2.0
    w7.eval.toFloatVector.head shouldBe 6.0

    val w8 = 2.0 / w1
    w8.eval.toFloatVector.head shouldBe 2.0
    val w9 = 2.0 - w2
    w9.eval.toFloatVector.head shouldBe 6.0f
  }

  "SameDiff" should "provide unary math operators" in {
    implicit val sd = SameDiff.create()
    val w1 = sd.bind("w1", 4.0.toScalar)
    var evaluated = w1.eval.castTo(DataType.DOUBLE)
    evaluated.toFloatVector.head shouldBe 4.0

    val w2 = -w1
    var evaluated2 = w2.eval.castTo(DataType.DOUBLE)
    evaluated2.toFloatVector.head shouldBe -4.0

    val w3 = w1 ** 2
    var evaluated3 = w3.eval.castTo(DataType.DOUBLE)
    evaluated3.toFloatVector.head shouldBe 16.0
  }
}
@@ -0,0 +1,191 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.nd4s.samediff

import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.Implicits._
import org.nd4s.samediff.implicits.Implicits._
import org.scalatest.{ FlatSpec, Matchers }

class MathTest extends FlatSpec with Matchers {

  "SameDiff" should "allow composition of arithmetic operations" in {

    val sd = SameDiff.create()
    val ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
    val w1 = sd.bind("w1", Nd4j.rand(DataType.FLOAT, 4, 5))
    val b1 = sd.bind("b1", Nd4j.rand(DataType.FLOAT, 5))

    val mmul1 = ph1 * w1
    val badd1 = mmul1 + b1

    val loss1 = badd1.std("loss1", true)

    sd.setLossVariables("loss1")
    sd.createGradFunction
    for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
      assert(v.getVarName != null && v.gradient != null)
    }
  }

  "SameDiff" should "provide arithmetic operations for float arguments in arbitrary order" in {

    implicit val sd = SameDiff.create()
    val w1 = sd.bind("w1", 4.0f.toScalar)
    var evaluated = w1.eval.castTo(DataType.FLOAT)
    evaluated.toFloatVector.head shouldBe 4.0f

    val w2 = w1 * 2.0f
    w2.eval.toFloatVector.head shouldBe 8.0f
    val w3 = w2 + 2.0f
    w3.eval.toFloatVector.head shouldBe 10.0f

    val w4 = 2.0f * w1
    w4.eval.toFloatVector.head shouldBe 8.0f
    val w5 = 2.0f + w2
    w5.eval.toFloatVector.head shouldBe 10.0f

    val w6 = w1 / 2.0f
    w6.eval.toFloatVector.head shouldBe 2.0f
    val w7 = w2 - 2.0f
    w7.eval.toFloatVector.head shouldBe 6.0f

    val w8 = 2.0f / w1
    w8.eval.toFloatVector.head shouldBe 2.0f

    val w9 = 2.0f - w2
    w9.eval.toFloatVector.head shouldBe 6.0f
  }

  "SameDiff" should "provide arithmetic operations for double arguments in arbitrary order" in {
    implicit val sd = SameDiff.create()
    val w1 = sd.bind("w1", 4.0.toScalar)
    var evaluated = w1.eval.castTo(DataType.DOUBLE)
    evaluated.toFloatVector.head shouldBe 4.0

    val w2 = w1 * 2.0
    w2.eval.toFloatVector.head shouldBe 8.0
    val w3 = w2 + 2.0
    w3.eval.toFloatVector.head shouldBe 10.0

    val w4 = 2.0 * w1
    w4.eval.toFloatVector.head shouldBe 8.0
    val w5 = 2.0 + w2
    w5.eval.toFloatVector.head shouldBe 10.0

    val w6 = w1 / 2.0
    w6.eval.toFloatVector.head shouldBe 2.0
    val w7 = w2 - 2.0
    w7.eval.toFloatVector.head shouldBe 6.0

    val w8 = 2.0 / w1
    w8.eval.toFloatVector.head shouldBe 2.0
    val w9 = 2.0 - w2
    w9.eval.toFloatVector.head shouldBe 6.0f
  }

  "SameDiff" should "provide floor division" in {
    implicit val sd = SameDiff.create()
    val w1 = sd.bind("w1", 4.0.toScalar)
    val w2 = sd.bind("w2", 1.2.toScalar)
    val w3 = w1 `//` w2
    w3.eval.toFloatVector.head shouldBe 3.0

    val w4 = w1 `//` 1.5
    w4.eval.toFloatVector.head shouldBe 2.0

    val w5 = 9.5 `//` w1
    w5.eval.toFloatVector.head shouldBe 2.0
  }

  "SameDiff" should "provide remainder division" in {
    implicit val sd = SameDiff.create()
    val w1 = sd.bind("w1", 40.0.toScalar)
    val w2 = sd.bind("w2", 12.0.toScalar)
    val w3 = w2 % w1
    w3.eval.toFloatVector.head shouldBe 12.0
    val w4 = w1 % w2
    w4.eval.toFloatVector.head shouldBe 4.0

    val w5 = w1 % 15.0
    w5.eval.toFloatVector.head shouldBe 10.0

    val w6 = 10.0 % w1
    w6.eval.toFloatVector.head shouldBe 10.0
  }

  "SameDiff" should "provide unary math operators" in {
    implicit val sd = SameDiff.create()
    val w1 = sd.bind("w1", 4.0.toScalar)
    var evaluated = w1.eval.castTo(DataType.DOUBLE)
    evaluated.toFloatVector.head shouldBe 4.0

    val w2 = -w1
    var evaluated2 = w2.eval.castTo(DataType.DOUBLE)
    evaluated2.toFloatVector.head shouldBe -4.0

    val w3 = w1 ** 2
    var evaluated3 = w3.eval.castTo(DataType.DOUBLE)
    evaluated3.toFloatVector.head shouldBe 16.0
  }

  "SameDiff" should "provide boolean logic operators" in {
    implicit val sd = SameDiff.create()
    val w1 = sd.constant(Nd4j.scalar(true))
    val w2 = sd.constant(Nd4j.scalar(true))

    val w3 = w1 | w2
    w3.eval.toIntVector.head shouldBe 1

    val w4 = w1 & w2
    w4.eval.toIntVector.head shouldBe 1

    val w5 = w1 ^ w2
    w5.eval.toIntVector.head shouldBe 0

    val w6 = w1 | false
    w6.eval.toIntVector.head shouldBe 1

    val w7 = w1 & false
    w7.eval.toIntVector.head shouldBe 0

    val w8 = w1 ^ false
    w8.eval.toIntVector.head shouldBe 1

    val w9 = false | w1
    w9.eval.toIntVector.head shouldBe 1

    val w10 = false & w1
    w10.eval.toIntVector.head shouldBe 0

    val w11 = false ^ w1
    w11.eval.toIntVector.head shouldBe 1
  }

  "SameDiff" should "provide shifting operations" in {
    implicit val sd = SameDiff.create()
    val w1 = sd.constant(16)

    val w2 = w1 << 2
    w2.eval.toIntVector.head shouldBe 64

    val w3 = w1 >> 2
    w3.eval.toIntVector.head shouldBe 4
  }
}
@@ -0,0 +1,123 @@
package org.nd4s.samediff

import java.lang.reflect.Field
import java.util
import java.util.{ Arrays, Collections, HashMap, List, Map }

import com.google.common.collect.{ Lists, Maps }
import org.junit.Assert._
import org.junit.Assume.assumeNotNull
import org.nd4j.autodiff.samediff._
import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional
import org.nd4j.autodiff.validation.{ OpValidation, TestCase }
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.api.blas.params.MMulTranspose
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.DynamicCustomOp
import org.nd4j.linalg.api.ops.impl.layers.{ ExternalErrorsFunction, Linear }
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.{ Conv2DConfig, LocalResponseNormalizationConfig }
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax
import org.nd4j.linalg.api.ops.impl.transforms.comparison.{ OldMax, OldMin }
import org.nd4j.linalg.api.ops.impl.transforms.custom._
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution
import org.nd4j.linalg.api.shape.LongShapeDescriptor
import org.nd4j.linalg.checkutil.NDArrayCreationUtil
import org.nd4j.linalg.dataset.{ DataSet, MultiDataSet }
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.indexing.NDArrayIndex
import org.nd4j.linalg.indexing.NDArrayIndex.all
import org.nd4j.linalg.learning.config.Adam
import org.nd4j.linalg.ops.transforms.Transforms
import org.nd4j.weightinit.impl.{ OneInitScheme, UniformInitScheme, ZeroInitScheme }
import org.nd4s.samediff.implicits.Implicits._
import org.scalatest.{ FlatSpec, Matchers }
import scala.collection.JavaConversions._

class SameDiffTest extends FlatSpec with Matchers {

  "SameDiff" should "allow Mse backwards execution" in {

    implicit val sd: SameDiff = SameDiff.create

    val nOut: Int = 4
    val minibatch: Int = 3
    val input: SDVariable = sd.bind("in", DataType.FLOAT, Array[Long](minibatch, nOut))
    val label: SDVariable = sd.bind("label", DataType.FLOAT, Array[Long](minibatch, nOut))

    val diff: SDVariable = input - label
    val sqDiff: SDVariable = diff * diff
    //val sqDiff: SDVariable = diff ** 2
    val msePerEx: SDVariable = sd.mean("msePerEx", sqDiff, 1)
    val avgMSE: SDVariable = sd.mean("loss", msePerEx, 0)

    val inputArr: INDArray = Nd4j.rand(DataType.FLOAT, minibatch, nOut)
    val labelArr: INDArray = Nd4j.rand(DataType.FLOAT, minibatch, nOut)

    sd.associateArrayWithVariable(inputArr, input)
    sd.associateArrayWithVariable(labelArr, label)

    val result: INDArray = sd.execAndEndResult
    assertEquals(1, result.length)

    val emptyMap = new HashMap[String, INDArray]()
    sd.execBackwards(emptyMap)
  }

  "SameDiff" should "run test dense layer forward pass" in {
    Nd4j.getRandom.setSeed(12345)
    implicit val sd = SameDiff.create
    val iInput = Nd4j.rand(3, 4)
    val iWeights = Nd4j.rand(4, 5)
    val iBias = Nd4j.rand(1, 5)
    val input = sd.bind("input", iInput)
    val weights = sd.bind("weights", iWeights)
    val bias = sd.bind("bias", iBias)
    val mmul = sd.mmul("mmul", input, weights)

    val z = mmul + bias

    val out = sd.nn.sigmoid("out", z)
    val expMmul = iInput.mmul(iWeights)
    val expZ = expMmul.addRowVector(iBias)
    val expOut = Transforms.sigmoid(expZ, true)
    sd.exec(new HashMap[String, INDArray](), sd.outputs)
    assertEquals(expMmul, mmul.getArr)
    assertEquals(expZ, z.getArr)
    assertEquals(expOut, out.getArr)
  }

  "SameDiff" should "convert placeholder to constant" in {
    Nd4j.getRandom.setSeed(12345)
    val sd = SameDiff.create
    val in = sd.placeHolder("in", DataType.FLOAT, 1, 3)
    val in2 = sd.placeHolder("in2", DataType.FLOAT, 3, 4)
    val b = sd.bind("b", Nd4j.rand(DataType.FLOAT, 1, 4))
    val mmul = in.mmul(in2)
    val add = mmul + b
    val tanh = sd.math.tanh(add)
    val loss = sd.variance(tanh, true)
    val inArr = Nd4j.rand(DataType.FLOAT, 1, 3)
    in.setArray(inArr)
    val inArr2 = Nd4j.rand(DataType.FLOAT, 3, 4)
    val c = TrainingConfig.builder
      .updater(new Adam(0.1))
      .weightDecay(0.01, true)
      .dataSetFeatureMapping("in", "in2")
      .skipBuilderValidation(true)
      .build
    sd.setTrainingConfig(c)
    sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(Array[INDArray](inArr, inArr2), null)), 1)
    val out = tanh.eval
    in.convertToConstant
    val out2 = tanh.eval
    assertEquals(out, out2)
    assertEquals(VariableType.CONSTANT, in.getVariableType)
    assertEquals(inArr, in.getArr)
    //Sanity check on fitting:
    sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(Array[INDArray](inArr2), null)), 1)
  }
}
@@ -0,0 +1,125 @@
package org.nd4s.samediff

import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff, TrainingConfig }
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.dataset.{ DataSet, MultiDataSet }
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.learning.config.Adam
import org.nd4s.Implicits._
import org.nd4s.samediff.implicits.Implicits._
import org.scalatest.{ FlatSpec, Matchers }

class TrainingTest extends FlatSpec with Matchers {

  "SameDiff" should "allow loss calculation" in {
    for (i <- 0 until 2) {
      implicit val sd = SameDiff.create
      val ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4)
      val w = sd.bind("w", Nd4j.rand(DataType.FLOAT, 4, 5))
      val b = sd.bind("b", Nd4j.rand(DataType.FLOAT, 5))
      val mmul = ph.mmul(w)
      val badd = mmul + b
      val add = badd + 1
      val shape = add.shape
      val unused1 = ph.mul(2)
      val unused2 = ph.sub(4)
      val unused3 = unused1.div(unused2)
      val loss1 = add.std("l1", true)
      val loss2 = mmul.mean("l2")
      Console.println(sd.summary)
      if (i == 0) {
        sd.setLossVariables("l1", "l2")
        sd.createGradFunction()
      } else {
        val tc = TrainingConfig.builder
          .updater(new Adam(0.01))
          .minimize("l1", "l2")
          .dataSetFeatureMapping("ph")
          .markLabelsUnused
          .build
        sd.setTrainingConfig(tc)
        val ds = new DataSet(Nd4j.create(3, 4), null)
        sd.fit(ds)
        sd.fit(ds)
      }
      for (s <- Array[String]("w", "b", badd.getVarName, add.getVarName, "l1", "l2")) {
        val gradVar = sd.getVariable(s).gradient
        assert(gradVar != null)
      }
      //Unused:
      assert(!shape.hasGradient)
      try assert(shape.gradient == null)
      catch {
        case e: IllegalStateException =>
          assert(e.getMessage.contains("only floating point variables"))
      }
      for (s <- Array[String](unused1.getVarName, unused2.getVarName, unused3.getVarName)) {
        assert(sd.getVariable(s).gradient == null)
      }
    }
  }

  "SameDiff" should "allow creating and running model with 2 losses: train on the first one, then change losses" in {
    // TODO: try to get rid of implicit here
    implicit val sd = SameDiff.create
    val ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
    val w1 = sd.bind("w1", Nd4j.rand(DataType.FLOAT, 4, 5))
    val b1 = sd.bind("b1", Nd4j.rand(DataType.FLOAT, 5))
    val mmul1 = ph1.mmul(w1)
    val badd1 = mmul1 + b1

    val ph2 = sd.placeHolder("ph2", DataType.FLOAT, 3, 2)
    val w2 = sd.bind("w2", Nd4j.rand(DataType.FLOAT, 2, 6))
    val b2 = sd.bind("b2", Nd4j.rand(DataType.FLOAT, 6))
    val mmul2 = ph2.mmul(w2)
    val badd2 = mmul2 + b2
    val loss1 = badd1.std("loss1", true)
    val loss2 = badd2.std("loss2", true)
    //First: create grad function for optimizing loss 1 only
    sd.setLossVariables("loss1")
    sd.createGradFunction()
    for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
      assert(v.gradient != null)
    }
    for (v <- Array[SDVariable](ph2, w2, b2, mmul2, badd2, loss2)) {
      assert(v.gradient == null)
    }
    //Now, set to other loss function
    sd.setLossVariables("loss2")
    sd.createGradFunction()
    for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
      assert(v.gradient == null)
    }
    for (v <- Array[SDVariable](ph2, w2, b2, mmul2, badd2, loss2)) {
      assert(v.gradient != null)
    }
    //Train the first side of the graph. The other side should remain unmodified!
    sd.setLossVariables("loss1")
    var w1Before = w1.getArr.dup
    var b1Before = b1.getArr.dup
    var w2Before = w2.getArr.dup
    var b2Before = b2.getArr.dup
    val tc = TrainingConfig.builder.updater(new Adam(1e-2)).dataSetFeatureMapping("ph1", "ph2").markLabelsUnused.build
    sd.setTrainingConfig(tc)
    val mds = new MultiDataSet(Array[INDArray](Nd4j.rand(DataType.FLOAT, 3, 4), Nd4j.rand(DataType.FLOAT, 3, 2)),
                               new Array[INDArray](0))
    sd.fit(new SingletonMultiDataSetIterator(mds), 3)
    assert(w1Before != w1.getArr)
    assert(b1Before != b1.getArr)
    assert(w2Before == w2.getArr)
    assert(b2Before == b2.getArr)
    //Train second side of graph; first side should be unmodified
    sd.setLossVariables("loss2")
    w1Before = w1.getArr.dup
    b1Before = b1.getArr.dup
    w2Before = w2.getArr.dup
    b2Before = b2.getArr.dup
    sd.fit(new SingletonMultiDataSetIterator(mds), 3)
    assert(w1Before == w1.getArr)
    assert(b1Before == b1.getArr)
    assert(w2Before != w2.getArr)
    assert(b2Before != b2.getArr)
  }
}