Fix SDLoss null weights array issue ()

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2020-01-25 20:13:23 +11:00 committed by GitHub
parent 4db28a9300
commit 458d141d8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 2 deletions
nd4j/nd4j-backends
nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops
nd4j-tests/src/test/java/org/nd4j/systeminfo

View File

@ -45,7 +45,7 @@ public class SDLoss extends SDOps {
*/
private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){
String weightName = (name == null) ? null : name + "/weight";
return (weights == null) ? null : sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights;
}
/**

View File

@ -17,9 +17,10 @@
package org.nd4j.systeminfo;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.systeminfo.SystemInfo;
public class TestSystemInfo {
public class TestSystemInfo extends BaseND4JTest {
@Test
public void testSystemInfo(){
SystemInfo.printSystemInfo();