Fix SDLoss null weights array issue (#185)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2020-01-25 20:13:23 +11:00 committed by GitHub
parent 4db28a9300
commit 458d141d8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 2 deletions

View File

@ -45,7 +45,7 @@ public class SDLoss extends SDOps {
*/
private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){
String weightName = (name == null) ? null : name + "/weight";
return (weights == null) ? null : sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights;
}
/**

View File

@ -17,9 +17,10 @@
package org.nd4j.systeminfo;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.systeminfo.SystemInfo;
public class TestSystemInfo {
public class TestSystemInfo extends BaseND4JTest {
@Test
public void testSystemInfo(){
SystemInfo.printSystemInfo();