parent
4db28a9300
commit
458d141d8e
|
@ -45,7 +45,7 @@ public class SDLoss extends SDOps {
|
||||||
*/
|
*/
|
||||||
private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){
|
private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){
|
||||||
String weightName = (name == null) ? null : name + "/weight";
|
String weightName = (name == null) ? null : name + "/weight";
|
||||||
return (weights == null) ? null : sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -17,9 +17,10 @@
|
||||||
package org.nd4j.systeminfo;
|
package org.nd4j.systeminfo;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.BaseND4JTest;
|
||||||
import org.nd4j.systeminfo.SystemInfo;
|
import org.nd4j.systeminfo.SystemInfo;
|
||||||
|
|
||||||
public class TestSystemInfo {
|
public class TestSystemInfo extends BaseND4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testSystemInfo(){
|
public void testSystemInfo(){
|
||||||
SystemInfo.printSystemInfo();
|
SystemInfo.printSystemInfo();
|
||||||
|
|
Loading…
Reference in New Issue