diff --git a/arbiter/arbiter-ui/pom.xml b/arbiter/arbiter-ui/pom.xml index 56d1013bf..03a335ea8 100644 --- a/arbiter/arbiter-ui/pom.xml +++ b/arbiter/arbiter-ui/pom.xml @@ -26,78 +26,13 @@ 4.0.0 - arbiter-ui_2.11 + arbiter-ui arbiter-ui 1.8 - - - - - buildUiTemplates - - false - - - - - com.google.code.play2-maven-plugin - play2-maven-plugin - ${maven-play2-plugin.version} - - - - - GenerateTemplates - compile - - - - template-compile - - - - - - - - maven-resources-plugin - ${maven-resources-plugin.version} - - - CopyTemplates - compile - - copy-resources - - - ${basedir}/src/main/scala/org/deeplearning4j/arbiter/ui/views/html - - - ${basedir}/target/twirl/main/org/deeplearning4j/arbiter/ui/views/html - true - - - - - - - - - - test-nd4j-native @@ -115,10 +50,17 @@ org.deeplearning4j - deeplearning4j-ui_2.11 + deeplearning4j-ui ${dl4j.version} + + ch.qos.logback + logback-classic + test + ${logback.version} + + org.deeplearning4j arbiter-deeplearning4j @@ -134,72 +76,6 @@ - - - org.apache.maven.wagon - wagon-http - 2.9 - - - - - - - org.codehaus.mojo - build-helper-maven-plugin - ${maven-build-helper-plugin.version} - - - add-source - compile - - add-source - - - - src/main/scala - src/main/java - src/main/views - - - - - - - - net.alchim31.maven - scala-maven-plugin - ${maven-scala-plugin.version} - - - - compile - testCompile - - - - - ${scala.version} - - - - - com.google.code.sbt-compiler-maven-plugin - sbt-compiler-maven-plugin - ${sbt-compiler-maven-plugin.version} - - - - - org.apache.maven.plugins - maven-compiler-plugin - - 1.8 - 1.8 - - - - diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/BaseJavaPersistable.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/BaseJavaPersistable.java index 68660a32f..3b5264f0d 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/BaseJavaPersistable.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/BaseJavaPersistable.java @@ -20,8 +20,6 @@ import lombok.AllArgsConstructor; import org.apache.commons.compress.utils.IOUtils; import org.deeplearning4j.api.storage.Persistable; import org.deeplearning4j.arbiter.ui.module.ArbiterModule; -import org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport; -import scala.annotation.meta.field; import java.io.*; import java.lang.reflect.Field; diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java index a7558451d..2f6bfc434 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,8 @@ package org.deeplearning4j.arbiter.ui.module; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.vertx.ext.web.RoutingContext; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.api.storage.Persistable; import org.deeplearning4j.api.storage.StatsStorage; @@ -31,8 +34,8 @@ import org.deeplearning4j.arbiter.ui.UpdateStatus; import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable; import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable; import org.deeplearning4j.arbiter.ui.misc.UIUtils; -import org.deeplearning4j.arbiter.ui.views.html.ArbiterUI; import org.deeplearning4j.arbiter.util.ObjectUtils; +import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.ui.api.Component; import org.deeplearning4j.ui.api.*; import org.deeplearning4j.ui.components.chart.ChartLine; @@ -48,18 +51,15 @@ import org.deeplearning4j.ui.i18n.I18NResource; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; import org.nd4j.linalg.primitives.Pair; -import play.libs.Json; -import play.mvc.Result; -import play.mvc.Results; +import org.nd4j.shade.jackson.core.JsonProcessingException; import java.awt.*; import java.text.DecimalFormat; -import java.util.*; import java.util.List; +import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; import static org.deeplearning4j.arbiter.ui.misc.JsonMapper.asJson; -import static play.mvc.Results.ok; /** * A Deeplearning4j {@link UIModule}, for integration with DL4J's user interface @@ -73,8 +73,6 @@ public class ArbiterModule implements UIModule { private static final DateTimeFormatter TIME_FORMATTER = DateTimeFormat.forPattern("YYYY-MM-dd HH:mm ZZ"); public static final String ARBITER_UI_TYPE_ID = "ArbiterUI"; - private static final String JSON = "application/json"; - private AtomicBoolean loggedArbiterAddress = new AtomicBoolean(false); private Map knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>()); private String currentSessionID; @@ -138,17 +136,18 @@ public class ArbiterModule implements UIModule { @Override public List getRoutes() { - Route r1 = new Route("/arbiter", HttpMethod.GET, FunctionType.Supplier, () -> Results.ok(ArbiterUI.apply())); - Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, FunctionType.Supplier, this::getLastUpdateTime); - Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, FunctionType.Function, this::getModelLastUpdateTimes); - Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, FunctionType.Function, this::getCandidateInfo); - Route r6 = new Route("/arbiter/config", HttpMethod.GET, FunctionType.Supplier, this::getOptimizationConfig); - Route r7 = new Route("/arbiter/results", HttpMethod.GET, FunctionType.Supplier, this::getSummaryResults); - Route r8 = new Route("/arbiter/summary", HttpMethod.GET, FunctionType.Supplier, this::getSummaryStatus); + Route r1 = new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() + .putHeader("content-type", "text/html; charset=utf-8").sendFile("templates/ArbiterUI.html")); + Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc)); + Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc)); + Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, (path, rc) -> 
this.getCandidateInfo(path.get(0), rc)); + Route r6 = new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc)); + Route r7 = new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc)); + Route r8 = new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc)); - Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, FunctionType.Supplier, this::listSessions); - Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, FunctionType.Supplier, this::currentSession); - Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, FunctionType.Function, this::setSession); + Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.listSessions(rc)); + Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc)); + Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)); return Arrays.asList(r1, r3, r4, r5, r6, r7, r8, r9a, r9b, r9c); } @@ -197,23 +196,27 @@ public class ArbiterModule implements UIModule { getDefaultSession(); } - private Result currentSession() { + private void currentSession(RoutingContext rc) { String sid = currentSessionID == null ? "" : currentSessionID; - return ok(asJson(sid)).as(JSON); + rc.response() + .putHeader("content-type", "application/json") + .end(asJson(sid)); } - private Result listSessions() { - return Results.ok(asJson(knownSessionIDs.keySet())).as(JSON); + private void listSessions(RoutingContext rc) { + rc.response() + .putHeader("content-type", "application/json") + .end(asJson(knownSessionIDs.keySet())); } - private Result setSession(String newSessionID) { + private void setSession(String newSessionID, RoutingContext rc) { log.debug("Arbiter UI: Set to session {}", newSessionID); if (knownSessionIDs.containsKey(newSessionID)) { currentSessionID = newSessionID; - return ok(); + rc.response().end(); } else { - return Results.badRequest("Unknown session ID: " + newSessionID); + rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()).end("Unknown session ID: " + newSessionID); } } @@ -255,30 +258,39 @@ public class ArbiterModule implements UIModule { } /** - * @return Last update time for the page + * Return the last update time for the page */ - private Result getLastUpdateTime(){ + private void getLastUpdateTime(RoutingContext rc){ //TODO - this forces updates on every request... 
which is fine, just inefficient long t = System.currentTimeMillis(); UpdateStatus us = new UpdateStatus(t, t, t); - return ok(Json.toJson(us)); + rc.response().putHeader("content-type", "application/json").end(asJson(us)); + } + + private String asJson(Object o){ + try{ + return JsonMappers.getMapper().writeValueAsString(o); + } catch (JsonProcessingException e){ + throw new RuntimeException("Error converting object to JSON", e); + } } /** * Get the last update time for the specified model IDs * @param modelIDs Model IDs to get the update time for */ - private Result getModelLastUpdateTimes(String modelIDs){ - + private void getModelLastUpdateTimes(String modelIDs, RoutingContext rc){ if(currentSessionID == null){ - return ok(); + rc.response().end(); + return; } StatsStorage ss = knownSessionIDs.get(currentSessionID); if(ss == null){ log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID); - return ok("-1"); + rc.response().end("-1"); + return; } String[] split = modelIDs.split(","); @@ -292,7 +304,7 @@ public class ArbiterModule implements UIModule { } } - return ok(Json.toJson(lastUpdateTimes)); + rc.response().putHeader("content-type", "application/json").end(asJson(lastUpdateTimes)); } /** @@ -301,12 +313,13 @@ public class ArbiterModule implements UIModule { * @param candidateId ID for the candidate * @return Content/info for the candidate */ - private Result getCandidateInfo(String candidateId){ + private void getCandidateInfo(String candidateId, RoutingContext rc){ StatsStorage ss = knownSessionIDs.get(currentSessionID); if(ss == null){ log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID); - return ok(); + rc.response().end(); + return; } GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);; @@ -316,7 +329,10 @@ public class ArbiterModule implements UIModule { if(p == null){ String title = "No results found for model " + candidateId + "."; ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build(); - return ok(asJson(ct)).as(JSON); + rc.response() + .putHeader("content-type", "application/json") + .end(asJson(ct)); + return; } ModelInfoPersistable mip = (ModelInfoPersistable)p; @@ -474,25 +490,27 @@ public class ArbiterModule implements UIModule { ComponentDiv cd = new ComponentDiv(STYLE_DIV_WIDTH_100_PC, components); - return ok(asJson(cd)).as(JSON); + rc.response().putHeader("content-type", "application/json").end(asJson(cd)); } /** * Get the optimization configuration - second section in the page */ - private Result getOptimizationConfig(){ + private void getOptimizationConfig(RoutingContext rc){ StatsStorage ss = knownSessionIDs.get(currentSessionID); if(ss == null){ log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); - return ok(); + rc.response().end(); + return; } Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); if(p == null){ - log.info("No static info"); - return ok(); + log.debug("No static info"); + rc.response().end(); + return; } List components = new ArrayList<>(); @@ -574,14 +592,15 @@ public class ArbiterModule implements UIModule { ComponentDiv cd = new ComponentDiv(STYLE_DIV_WIDTH_100_PC, components); - return ok(asJson(cd)).as(JSON); + rc.response().putHeader("content-type", "application/json").end(asJson(cd)); } - private Result getSummaryResults(){ + private void 
getSummaryResults(RoutingContext rc){ StatsStorage ss = knownSessionIDs.get(currentSessionID); if(ss == null){ log.debug("getSummaryResults(): Session ID is unknown: {}", currentSessionID); - return ok(); + rc.response().end(); + return; } List allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); @@ -592,24 +611,26 @@ public class ArbiterModule implements UIModule { table.add(new String[]{mip.getModelIdx().toString(), score, mip.getStatus().toString()}); } - return ok(asJson(table)).as(JSON); + rc.response().putHeader("content-type", "application/json").end(asJson(table)); } /** * Get summary status information: first section in the page */ - private Result getSummaryStatus(){ + private void getSummaryStatus(RoutingContext rc){ StatsStorage ss = knownSessionIDs.get(currentSessionID); if(ss == null){ log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); - return ok(); + rc.response().end(); + return; } Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); if(p == null){ log.info("No static info"); - return ok(); + rc.response().end(); + return; } GlobalConfigPersistable gcp = (GlobalConfigPersistable)p; @@ -711,7 +732,7 @@ public class ArbiterModule implements UIModule { ComponentDiv cd = new ComponentDiv(STYLE_DIV_WIDTH_100_PC, components); - return ok(asJson(cd)).as(JSON); + rc.response().putHeader("content-type", "application/json").end(asJson(cd)); } diff --git a/arbiter/arbiter-ui/src/main/views/org/deeplearning4j/arbiter/ui/views/ArbiterUI.scala.html b/arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html similarity index 98% rename from arbiter/arbiter-ui/src/main/views/org/deeplearning4j/arbiter/ui/views/ArbiterUI.scala.html rename to arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html index 372de64c5..9a5b6a998 100644 --- a/arbiter/arbiter-ui/src/main/views/org/deeplearning4j/arbiter/ui/views/ArbiterUI.scala.html +++ b/arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html @@ -1,5 +1,6 @@ - - - - - DL4J - Arbiter UI - - - - - - - - - - - - - - - - - - - - -
[Remaining template hunks elided: this is a 98%-similarity rename, and the page markup (a "Deeplearning4J - Arbiter UI" navigation bar with four content panes titled "Summary", "Optimization Settings", "Results", and "Selected Result") carries over unchanged. The deleted .scala.html side additionally drops its Twirl wrapper: the generated render()/f()/ref methods and the closing "-- GENERATED --" comment with DATE/SOURCE/HASH/MATRIX/LINES position-mapping metadata.]
\ No newline at end of file
diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java
index aae674ab4..804c6f974 100644
--- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java
+++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java
@@ -47,13 +47,13 @@ import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
-import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.junit.Ignore;
import org.junit.Test;
+import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
diff --git a/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java b/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java
index 04773ee74..4c4d2844f 100644
--- a/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java
+++ b/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java
@@ -71,7 +71,7 @@ public class DL4JSystemProperties {
public static final String CRASH_DUMP_OUTPUT_DIRECTORY_PROPERTY = "org.deeplearning4j.crash.reporting.directory";
/**
- * Applicability: deeplearning4j-ui_2.xx
+ * Applicability: deeplearning4j-ui
* Description: The DL4J training UI (StatsListener + UIServer.getInstance().attach(ss)) will subsample the number * of chart points when a lot of data is present - i.e., only a maximum number of points will be shown on each chart. * This is to reduce the UI bandwidth requirements and client-side rendering cost. @@ -81,7 +81,7 @@ public class DL4JSystemProperties { /** - * Applicability: deeplearning4j-play (deeplearning4j-ui_2.xx)
+ * Applicability: deeplearning4j-vertx (deeplearning4j-ui)
* Description: This property sets the port that the UI will be available on. Default port: 9000. * Set to 0 for a random port. */ diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java index 75a0ad759..d506bb233 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java @@ -142,7 +142,7 @@ public class GradientCheckTests extends BaseDL4JTest { // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (c) Loss function (with specified output activations) - Activation[] activFns = {Activation.SIGMOID, Activation.TANH}; + Activation[] activFns = {Activation.SIGMOID, Activation.TANH, Activation.MISH}; boolean[] characteristic = {false, true}; //If true: run some backprop steps first LossFunction[] lossFunctions = {LossFunction.MCXENT, LossFunction.MSE}; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java index 56d53d07a..fa06ff8f7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java @@ -83,6 +83,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(), new LossMultiLabel(), new LossWasserstein(), + new LossSparseMCXENT() }; Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent @@ -116,7 +117,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { Activation.IDENTITY, // MixtureDensity Activation.TANH, // MixtureDensity + tanh Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper - Activation.IDENTITY // Wasserstein + Activation.IDENTITY, // Wasserstein + Activation.SOFTMAX, //sparse MCXENT }; int[] nOut = new int[] {1, //xent @@ -151,6 +153,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { 10, // Mixture Density + tanh 10, // MultiLabel 2, // Wasserstein + 4, //sparse MCXENT }; int[] minibatchSizes = new int[] {1, 3}; @@ -233,7 +236,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0), new LossFMeasure(), new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(), - new LossMultiLabel(), new LossWasserstein() + new LossMultiLabel(), new LossWasserstein(), + new LossSparseMCXENT() }; Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent @@ -266,7 +270,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { Activation.IDENTITY, // MixtureDensity Activation.TANH, // MixtureDensity + tanh Activation.TANH, // MultiLabel - Activation.IDENTITY // Wasserstein + Activation.IDENTITY, // Wasserstein + Activation.SOFTMAX }; int[] nOut = new int[] {1, //xent @@ -300,6 
+305,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { 10, // Mixture Density + tanh 10, // MultiLabel 2, // Wasserstein + 4 }; int[] minibatchSizes = new int[] {1, 3}; @@ -476,6 +482,23 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { throw new UnsupportedOperationException(); } + break; + case "LossSparseMCXENT": + if (labelsShape.length == 2) { + ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1); + for (int i = 0; i < labelsShape[0]; i++) { + ret[1].putScalar(i, 0, r.nextInt((int) labelsShape[1])); + } + } else if (labelsShape.length == 3) { + ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1, labelsShape[2]); + for (int i = 0; i < labelsShape[0]; i++) { + for (int j = 0; j < labelsShape[2]; j++) { + ret[1].putScalar(i, 0, j, r.nextInt((int) labelsShape[1])); + } + } + } else { + throw new UnsupportedOperationException(); + } break; case "LossHinge": case "LossSquaredHinge": diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java index 2552b6072..32a229101 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java @@ -34,6 +34,7 @@ import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; import org.nd4j.linalg.lossfunctions.impl.LossMSE; +import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT; import java.util.Random; @@ -61,19 +62,12 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { int nOut = 2; int miniBatchSize = 3; - ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT()}; + ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT(), new LossSparseMCXENT()}; for (int maskType = 0; maskType < 3; maskType++) { Random r = new Random(12345L); INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength}); - INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength); - for (int i = 0; i < miniBatchSize; i++) { - for (int j = 0; j < timeSeriesLength; j++) { - int idx = r.nextInt(nOut); - labels.putScalar(new int[]{i, idx, j}, 1.0); - } - } INDArray labelMask; String mt; @@ -85,13 +79,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { break; case 1: //Per time step masking - labelMask = Nd4j.createUninitialized(miniBatchSize, timeSeriesLength); + labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, timeSeriesLength); Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5)); mt = "PerTimeStep"; break; case 2: //Per output masking: - labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, nOut, timeSeriesLength}); + labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, nOut, timeSeriesLength); Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5)); mt = "PerOutput"; break; @@ -101,6 +95,26 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { for (ILossFunction lf : lfs) { + INDArray labels; + if(lf instanceof LossSparseMCXENT){ + labels = Nd4j.zeros(miniBatchSize, 1, timeSeriesLength); + for (int i = 0; i < miniBatchSize; i++) { + for (int j = 0; j < timeSeriesLength; j++) { + int idx = r.nextInt(nOut); + labels.putScalar(new 
int[]{i, 0, j}, idx); + } + } + } else { + labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength); + for (int i = 0; i < miniBatchSize; i++) { + for (int j = 0; j < timeSeriesLength; j++) { + int idx = r.nextInt(nOut); + labels.putScalar(new int[]{i, idx, j}, 1.0); + } + } + } + + Activation oa = maskType == 2 ? Activation.SIGMOID : Activation.SOFTMAX; MultiLayerConfiguration conf = diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java index 6a5e1ff2a..4035e9298 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java @@ -75,7 +75,6 @@ public class KerasReshape extends KerasLayer { List targetShapeList = (List) innerConfig.get(targetShape); this.targetShape = listToLongArray(targetShapeList); } - } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java index 77c6369c5..e9aef5b90 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java @@ -43,13 +43,14 @@ import static org.nd4j.linalg.util.ArrayUtil.prodLong; @Data @Slf4j @EqualsAndHashCode(callSuper = false) -@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize"}) +@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"}) public class ReshapePreprocessor extends BaseInputPreProcessor { private long[] inputShape; private long[] targetShape; private boolean hasMiniBatchDimension = false; private int miniBatchSize; + private long[] staticTargetShape; public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) { this.inputShape = inputShape; @@ -80,10 +81,19 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { // the target shape read from a keras config does not have mini-batch size // included. We prepend it here dynamically. 
+ + long[] targetShape; + if (staticTargetShape != null){ + targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize); + hasMiniBatchDimension = true; + this.miniBatchSize = miniBatchSize; + } + else{ + targetShape = this.targetShape; + } if (!this.hasMiniBatchDimension) { targetShape = prependMiniBatchSize(targetShape, miniBatchSize); inputShape = prependMiniBatchSize(inputShape, miniBatchSize); - this.hasMiniBatchDimension = true; this.miniBatchSize = miniBatchSize; } if (this.miniBatchSize != miniBatchSize) { @@ -95,7 +105,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){ input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); } - return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(this.targetShape)); + return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape)); } else { throw new IllegalStateException("Input shape " + Arrays.toString(input.shape()) + " and output shape" + Arrays.toString(inputShape) + " do not match"); @@ -122,20 +132,27 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { @Override public InputType getOutputType(InputType inputType) throws InvalidInputTypeException { val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0); + InputType ret; switch (shape.length) { case 2: - return InputType.feedForward(shape[1]); + ret = InputType.feedForward(shape[1]); + break; case 3: - return InputType.recurrent(shape[2], shape[1]); + ret = InputType.recurrent(shape[2], shape[1]); + break; case 4: if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){ - return InputType.convolutional(shape[1], shape[2], shape[3]); + ret = InputType.convolutional(shape[1], shape[2], shape[3]); }else { - return InputType.convolutional(shape[2], shape[3], shape[1]); + ret = InputType.convolutional(shape[2], shape[3], shape[1]); } + break; default: throw new UnsupportedOperationException( "Cannot infer input type for reshape array " + Arrays.toString(shape)); + } + this.staticTargetShape = ret.getShape(); + return ret; } } \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java index 1149401c6..35cf34170 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java @@ -54,7 +54,7 @@ public class KerasLossUtils { } else if (kerasLoss.equals(conf.getKERAS_LOSS_HINGE())) { dl4jLoss = LossFunctions.LossFunction.HINGE; } else if (kerasLoss.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) { - throw new UnsupportedKerasConfigurationException("Loss function " + kerasLoss + " not supported yet."); + dl4jLoss = LossFunctions.LossFunction.SPARSE_MCXENT; } else if (kerasLoss.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) { dl4jLoss = LossFunctions.LossFunction.XENT; } else if (kerasLoss.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java 
b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index 81103d315..db03128f7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -255,6 +255,17 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { } } + @Test + public void ReshapeEmbeddingConcatTest() throws Exception{ + try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) { + ComputationGraphConfiguration config = + new KerasModel().modelBuilder().modelJsonInputStream(is) + .enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration(); + ComputationGraph model = new ComputationGraph(config); + model.init(); + model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1)); + } + } private void runSequentialConfigTest(String path) throws Exception { try(InputStream is = Resources.asStream(path)) { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index b33ff8d1f..0565cc091 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -49,6 +49,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT; import org.nd4j.resources.Resources; import java.io.File; @@ -88,7 +89,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Test(expected = IllegalStateException.class) public void fileNotFoundEndToEnd() throws Exception { String modelPath = "modelimport/keras/examples/foo/bar.h5"; - importEndModelTest(modelPath, null, true, true, false); + importEndModelTest(modelPath, null, true, true, false, false); } /** @@ -98,28 +99,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importMnistMlpTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importMnistMlpThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, false); + importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); } @Test public void importMnistMlpTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5"; String inputsOutputPath = 
"modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importMnistMlpReshapeTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, true); + importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); } /** @@ -129,21 +130,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importMnistCnnTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); } @Test public void importMnistCnnThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, true); + importEndModelTest(modelPath, inputsOutputPath, false, true, true, false); } @Test public void importMnistCnnTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, true); + importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); } /** @@ -153,28 +154,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importImdbLstmTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importImdbLstmThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importImdbLstmTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importImdbLstmThKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; - 
importEndModelTest(modelPath, inputsOutputPath, false, true, false); + importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); } /** @@ -185,21 +186,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importImdbFasttextTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, false, false); + importEndModelTest(modelPath, inputsOutputPath, false, false, false, false); } @Test public void importImdbFasttextThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, false, false); + importEndModelTest(modelPath, inputsOutputPath, false, false, false, false); } @Test public void importImdbFasttextTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); } /** @@ -209,21 +210,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importSimpleLstmTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importSimpleLstmThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importSimpleLstmTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); } @@ -235,7 +236,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { String modelPath = "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" + "simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } /** @@ -246,7 +247,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; String inputsOutputPath = 
"modelimport/keras/examples/simple_flatten_rnn/" + "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } /** @@ -257,7 +258,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { String modelPath = "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" + "simple_rnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } /** @@ -267,7 +268,18 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importCnnNoBiasTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, true); + importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); + } + + @Test + public void importSparseXent() throws Exception { + String modelPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_model.h5"; + String inputsOutputPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_inputs_and_outputs.h5"; + MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, true, true); + Layer outLayer = net.getOutputLayer(); + assertTrue(outLayer instanceof org.deeplearning4j.nn.layers.LossLayer); + LossLayer llConf = (LossLayer) outLayer.getConfig(); + assertEquals(new LossSparseMCXENT(), llConf.getLossFn()); } /** @@ -626,19 +638,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } } - - private void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions) throws Exception { - importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, false); - } - - public void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, - boolean checkGradients) throws Exception { + public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, + boolean checkGradients, boolean enforceTrainingConfig) throws Exception { MultiLayerNetwork model; try(InputStream is = Resources.asStream(modelPath)) { File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) - .enforceTrainingConfig(false).buildSequential(); + .enforceTrainingConfig(enforceTrainingConfig).buildSequential(); model = kerasModel.getMultiLayerNetwork(); } @@ -699,6 +706,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { checkGradients(model, input, testLabels); } } + + return model; } private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception { diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml index 38ee4204c..66c639f56 100644 --- 
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml @@ -34,16 +34,6 @@ 2.11 - - - maven-compiler-plugin - - 1.8 - 1.8 - - - - @@ -56,9 +46,14 @@ *.java **/*.java - - ../../deeplearning4j-ui-parent/deeplearning4j-play/src/main/resources/application.conf - + + + + org.apache.maven.plugins + maven-compiler-plugin + + 1.8 + 1.8 diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml index 44fbbcf9d..619d36db6 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml @@ -68,7 +68,7 @@ org.deeplearning4j - deeplearning4j-ui_2.11 + deeplearning4j-ui ${project.version} test diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java index c6a88ffb3..b5ca6c91a 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -79,6 +80,17 @@ import java.util.Map; * .build(); * } * + *
+ * Example to use an instantiated iterator for inference:
+ *
+ * {@code
+ *          BertIterator b;
+ *          List forInference;
+ *          List<String> forInference;
+ *          Pair<INDArray[], INDArray[]> featuresAndMask = b.featurizeSentences(forInference);
+ *          INDArray[] featureMasks = featuresAndMask.getSecond();
+ * }
+ * 
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.
*
* {@link LengthHandling} configuration:
@@ -89,7 +101,7 @@ import java.util.Map; * CLIP_ONLY: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in * a minibatch is shorter than the specified maximum, no padding will occur. Sequences that are shorter than the * maximum (within the current minibatch) will be zero padded and masked.
- *

+ *

* {@link FeatureArrays} configuration:
* Determines what arrays should be included.
* INDICES_MASK: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).
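Taken together, the above works as in the following minimal sketch of the new inference path. This is a sketch only: the vocabulary file path is a placeholder, and the builder options (FIXED_LENGTH of 16, INDICES_MASK) are simply the ones exercised by the tests later in this patch. Note that featurizeSentences(...) does not consume the sentence provider, which is only needed when iterating training data via next().

```java
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;

import org.deeplearning4j.iterator.BertIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;

public class BertIteratorInferenceSketch {
    public static void main(String[] args) throws Exception {
        File vocabFile = new File("vocab.txt");   // placeholder: any BERT wordpiece vocab file
        BertWordPieceTokenizerFactory tf =
                new BertWordPieceTokenizerFactory(vocabFile, false, false, StandardCharsets.UTF_8);

        // task, tokenizer and vocabMap are required by build(); a sentence provider
        // is only needed for training iteration, not for featurizeSentences(...)
        BertIterator iter = BertIterator.builder()
                .tokenizer(tf)
                .task(BertIterator.Task.SEQ_CLASSIFICATION)
                .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
                .minibatchSize(2)
                .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
                .vocabMap(tf.getVocab())
                .build();

        // New in this patch: convert raw sentences straight to (features, feature masks)
        List<String> sentences = Arrays.asList("I saw a girl with a telescope.");
        Pair<INDArray[], INDArray[]> fm = iter.featurizeSentences(sentences);

        INDArray tokenIndices = fm.getFirst()[0];  // INT array, shape [minibatch, 16]
        INDArray featureMask = fm.getSecond()[0];  // 1 = real token position, 0 = padding
        System.out.println(Arrays.toString(tokenIndices.shape()));
    }
}
```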
@@ -107,8 +119,11 @@ import java.util.Map; public class BertIterator implements MultiDataSetIterator { public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION} + public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY} + public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID} + public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC} protected Task task; @@ -116,12 +131,13 @@ public class BertIterator implements MultiDataSetIterator { protected int maxTokens = -1; protected int minibatchSize = 32; protected boolean padMinibatches = false; - @Getter @Setter + @Getter + @Setter protected MultiDataSetPreProcessor preProcessor; protected LabeledSentenceProvider sentenceProvider = null; protected LengthHandling lengthHandling; protected FeatureArrays featureArrays; - protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? + protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? protected BertSequenceMasker masker = null; protected UnsupervisedLabelFormat unsupervisedLabelFormat = null; protected String maskToken; @@ -130,7 +146,7 @@ public class BertIterator implements MultiDataSetIterator { protected List vocabKeysAsList; - protected BertIterator(Builder b){ + protected BertIterator(Builder b) { this.task = b.task; this.tokenizerFactory = b.tokenizerFactory; this.maxTokens = b.maxTokens; @@ -166,10 +182,10 @@ public class BertIterator implements MultiDataSetIterator { public MultiDataSet next(int num) { Preconditions.checkState(hasNext(), "No next element available"); - List> list = new ArrayList<>(num); - int count = 0; - if(sentenceProvider != null){ - while(sentenceProvider.hasNext() && count++ < num) { + List> list = new ArrayList<>(num); + int mbSize = 0; + if (sentenceProvider != null) { + while (sentenceProvider.hasNext() && mbSize++ < num) { list.add(sentenceProvider.nextSentence()); } } else { @@ -177,41 +193,60 @@ public class BertIterator implements MultiDataSetIterator { throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented"); } - //Get and tokenize the sentences for this minibatch - List, String>> tokenizedSentences = new ArrayList<>(num); - int longestSeq = -1; - for(Pair p : list){ - List tokens = tokenizeSentence(p.getFirst()); - tokenizedSentences.add(new Pair<>(tokens, p.getSecond())); - longestSeq = Math.max(longestSeq, tokens.size()); - } - //Determine output array length... - int outLength; - switch (lengthHandling){ - case FIXED_LENGTH: - outLength = maxTokens; - break; - case ANY_LENGTH: - outLength = longestSeq; - break; - case CLIP_ONLY: - outLength = Math.min(maxTokens, longestSeq); - break; - default: - throw new RuntimeException("Not implemented length handling mode: " + lengthHandling); - } + Pair, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list); + List, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight(); + int outLength = outLTokenizedSentencesPair.getLeft(); - int mb = tokenizedSentences.size(); - int mbPadded = padMinibatches ? 
minibatchSize : mb; + Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength); + INDArray[] featureArray = featuresAndMaskArraysPair.getFirst(); + INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond(); + + + Pair labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength); + INDArray[] labelArray = labelsAndMaskArraysPair.getFirst(); + INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond(); + + org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(featureArray, labelArray, featureMaskArray, labelMaskArray); + if (preProcessor != null) + preProcessor.preProcess(mds); + + return mds; + } + + + /** + * For use during inference. Will convert a given list of sentences to features and feature masks as appropriate. + * + * @param listOnlySentences + * @return Pair of INDArrays[], first element is feature arrays and the second is the masks array + */ + public Pair featurizeSentences(List listOnlySentences) { + + List> sentencesWithNullLabel = addDummyLabel(listOnlySentences); + + Pair, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel); + List, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight(); + int outLength = outLTokenizedSentencesPair.getLeft(); + + Pair featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength); + if (preProcessor != null) { + MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null); + preProcessor.preProcess(dummyMDS); + return new Pair(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays()); + } + return convertMiniBatchFeatures(tokenizedSentences, outLength); + } + + private Pair convertMiniBatchFeatures(List, String>> tokenizedSentences, int outLength) { + int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size(); int[][] outIdxs = new int[mbPadded][outLength]; int[][] outMask = new int[mbPadded][outLength]; - - for( int i=0; i,String> p = tokenizedSentences.get(i); + for (int i = 0; i < tokenizedSentences.size(); i++) { + Pair, String> p = tokenizedSentences.get(i); List t = p.getFirst(); - for( int j=0; j(f, fm); + } + private Pair, String>>> tokenizeMiniBatch(List> list) { + //Get and tokenize the sentences for this minibatch + List, String>> tokenizedSentences = new ArrayList<>(list.size()); + int longestSeq = -1; + for (Pair p : list) { + List tokens = tokenizeSentence(p.getFirst()); + tokenizedSentences.add(new Pair<>(tokens, p.getSecond())); + longestSeq = Math.max(longestSeq, tokens.size()); + } + + //Determine output array length... + int outLength; + switch (lengthHandling) { + case FIXED_LENGTH: + outLength = maxTokens; + break; + case ANY_LENGTH: + outLength = longestSeq; + break; + case CLIP_ONLY: + outLength = Math.min(maxTokens, longestSeq); + break; + default: + throw new RuntimeException("Not implemented length handling mode: " + lengthHandling); + } + return new Pair<>(outLength, tokenizedSentences); + } + + private Pair convertMiniBatchLabels(List, String>> tokenizedSentences, INDArray[] featureArray, int outLength) { INDArray[] l = new INDArray[1]; INDArray[] lm; - if(task == Task.SEQ_CLASSIFICATION){ + int mbSize = tokenizedSentences.size(); + int mbPadded = padMinibatches ? 
minibatchSize : tokenizedSentences.size(); + if (task == Task.SEQ_CLASSIFICATION) { //Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses] int numClasses; int[] classLabels = new int[mbPadded]; - if(sentenceProvider != null){ + if (sentenceProvider != null) { numClasses = sentenceProvider.numLabelClasses(); List labels = sentenceProvider.allLabels(); - for(int i=0; i= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); @@ -252,21 +320,21 @@ public class BertIterator implements MultiDataSetIterator { throw new RuntimeException(); } l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses); - for( int i=0; i e : vocabMap.entrySet()){ + for (Map.Entry e : vocabMap.entrySet()) { arr[e.getValue()] = e.getKey(); } vocabKeysAsList = Arrays.asList(arr); @@ -276,31 +344,31 @@ public class BertIterator implements MultiDataSetIterator { int vocabSize = vocabMap.size(); INDArray labelArr; INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength); - if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){ + if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) { labelArr = Nd4j.create(DataType.INT, mbPadded, outLength); - } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){ + } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) { labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength); - } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){ + } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) { labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize); } else { throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat); } - for( int i=0; i tokens = tokenizedSentences.get(i).getFirst(); - Pair,boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList); + Pair, boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList); List maskedTokens = p.getFirst(); boolean[] predictionTarget = p.getSecond(); int seqLen = Math.min(predictionTarget.length, outLength); - for(int j=0; j(l, lm); } private List tokenizeSentence(String sentence) { Tokenizer t = tokenizerFactory.create(sentence); List tokens = new ArrayList<>(); - if(prependToken != null) + if (prependToken != null) tokens.add(prependToken); while (t.hasMoreTokens()) { @@ -341,6 +405,16 @@ public class BertIterator implements MultiDataSetIterator { return tokens; } + + private List> addDummyLabel(List listOnlySentences) { + List> list = new ArrayList<>(listOnlySentences.size()); + for (String s : listOnlySentences) { + list.add(new Pair(s, null)); + } + return list; + } + + @Override public boolean resetSupported() { return true; @@ -353,12 +427,12 @@ public class BertIterator implements MultiDataSetIterator { @Override public void reset() { - if(sentenceProvider != null){ + if (sentenceProvider != null) { sentenceProvider.reset(); } } - public static Builder builder(){ + public static Builder builder() { return new Builder(); } @@ -373,7 +447,7 @@ public class BertIterator implements MultiDataSetIterator { protected MultiDataSetPreProcessor preProcessor; protected LabeledSentenceProvider sentenceProvider = null; protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID; - protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? + protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? 
protected BertSequenceMasker masker = new BertMaskedLMMasker(); protected UnsupervisedLabelFormat unsupervisedLabelFormat; protected String maskToken; @@ -382,7 +456,7 @@ public class BertIterator implements MultiDataSetIterator { /** * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details. */ - public Builder task(Task task){ + public Builder task(Task task) { this.task = task; return this; } @@ -392,18 +466,19 @@ public class BertIterator implements MultiDataSetIterator { * For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory} * is used */ - public Builder tokenizer(TokenizerFactory tokenizerFactory){ + public Builder tokenizer(TokenizerFactory tokenizerFactory) { this.tokenizerFactory = tokenizerFactory; return this; } /** * Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details. - * @param lengthHandling Length handling - * @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH} + * + * @param lengthHandling Length handling + * @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH} * @return */ - public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength){ + public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength) { this.lengthHandling = lengthHandling; this.maxTokens = maxLength; return this; @@ -412,9 +487,10 @@ public class BertIterator implements MultiDataSetIterator { /** * Minibatch size to use (number of examples to train on for each iteration) * See also: {@link #padMinibatches} - * @param minibatchSize Minibatch size + * + * @param minibatchSize Minibatch size */ - public Builder minibatchSize(int minibatchSize){ + public Builder minibatchSize(int minibatchSize) { this.minibatchSize = minibatchSize; return this; } @@ -429,7 +505,7 @@ public class BertIterator implements MultiDataSetIterator { * Both options should result in exactly the same model. However, some BERT implementations may require exactly an * exact number of examples in all minibatches to function. */ - public Builder padMinibatches(boolean padMinibatches){ + public Builder padMinibatches(boolean padMinibatches) { this.padMinibatches = padMinibatches; return this; } @@ -437,7 +513,7 @@ public class BertIterator implements MultiDataSetIterator { /** * Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null) */ - public Builder preProcessor(MultiDataSetPreProcessor preProcessor){ + public Builder preProcessor(MultiDataSetPreProcessor preProcessor) { this.preProcessor = preProcessor; return this; } @@ -446,7 +522,7 @@ public class BertIterator implements MultiDataSetIterator { * Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised * use case, the labels will be ignored. */ - public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider){ + public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) { this.sentenceProvider = sentenceProvider; return this; } @@ -454,7 +530,7 @@ public class BertIterator implements MultiDataSetIterator { /** * Specify what arrays should be returned. See {@link BertIterator} for more details. 
*/ - public Builder featureArrays(FeatureArrays featureArrays){ + public Builder featureArrays(FeatureArrays featureArrays) { this.featureArrays = featureArrays; return this; } @@ -465,7 +541,7 @@ public class BertIterator implements MultiDataSetIterator { * If using {@link BertWordPieceTokenizerFactory}, * this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()} */ - public Builder vocabMap(Map vocabMap){ + public Builder vocabMap(Map vocabMap) { this.vocabMap = vocabMap; return this; } @@ -475,7 +551,7 @@ public class BertIterator implements MultiDataSetIterator { * masked language model. This can be used to customize how the masking is performed.
* Default: {@link BertMaskedLMMasker} */ - public Builder masker(BertSequenceMasker masker){ + public Builder masker(BertSequenceMasker masker) { this.masker = masker; return this; } @@ -485,7 +561,7 @@ public class BertIterator implements MultiDataSetIterator { * masked language model. Used to specify the format that the labels should be returned in. * See {@link BertIterator} for more details. */ - public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat){ + public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat) { this.unsupervisedLabelFormat = labelFormat; return this; } @@ -497,7 +573,7 @@ public class BertIterator implements MultiDataSetIterator { * the exact behaviour will depend on what masker is used.
* Note that this must be in the vocabulary map set in {@link #vocabMap} */ - public Builder maskToken(String maskToken){ + public Builder maskToken(String maskToken) { this.maskToken = maskToken; return this; } @@ -510,12 +586,12 @@ public class BertIterator implements MultiDataSetIterator { * * @param prependToken The token to start each sequence with (null: no token will be prepended) */ - public Builder prependToken(String prependToken){ + public Builder prependToken(String prependToken) { this.prependToken = prependToken; return this; } - public BertIterator build(){ + public BertIterator build() { Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed"); Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required"); Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map) to set"); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java index 90879f858..d4be5e352 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java @@ -26,7 +26,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.primitives.Pair; import org.nd4j.resources.Resources; @@ -34,10 +33,7 @@ import java.io.File; import java.io.IOException; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Random; +import java.util.*; import static org.junit.Assert.*; @@ -54,6 +50,9 @@ public class TestBertIterator extends BaseDL4JTest { String toTokenize1 = "I saw a girl with a telescope."; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + List forInference = new ArrayList<>(); + forInference.add(toTokenize1); + forInference.add(toTokenize2); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); BertIterator b = BertIterator.builder() @@ -100,12 +99,15 @@ public class TestBertIterator extends BaseDL4JTest { assertEquals(expF, mds.getFeatures(0)); assertEquals(expM, mds.getFeaturesMaskArray(0)); + assertEquals(expF,b.featurizeSentences(forInference).getFirst()[0]); + assertEquals(expM,b.featurizeSentences(forInference).getSecond()[0]); assertFalse(b.hasNext()); b.reset(); assertTrue(b.hasNext()); MultiDataSet mds2 = b.next(); + forInference.set(0,toTokenize2); //Same thing, but with segment ID also b = BertIterator.builder() .tokenizer(t) @@ -118,9 +120,11 @@ public class TestBertIterator extends BaseDL4JTest { .build(); mds = b.next(); assertEquals(2, mds.getFeatures().length); + assertEquals(2,b.featurizeSentences(forInference).getFirst().length); //Segment ID should be all 0s for single segment task INDArray segmentId = expM.like(); assertEquals(segmentId, mds.getFeatures(1)); + 
assertEquals(segmentId,b.featurizeSentences(forInference).getFirst()[1]); } @Test(timeout = 20000L) @@ -157,6 +161,9 @@ public class TestBertIterator extends BaseDL4JTest { public void testLengthHandling() throws Exception { String toTokenize1 = "I saw a girl with a telescope."; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + List forInference = new ArrayList<>(); + forInference.add(toTokenize1); + forInference.add(toTokenize2); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); @@ -205,6 +212,8 @@ public class TestBertIterator extends BaseDL4JTest { assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0)); assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0)); + assertEquals(mds.getFeatures(0),b.featurizeSentences(forInference).getFirst()[0]); + assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]); //Clip only: clip to maximum, but don't pad if less b = BertIterator.builder() @@ -227,6 +236,9 @@ public class TestBertIterator extends BaseDL4JTest { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); String toTokenize1 = "I saw a girl with a telescope."; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + List forInference = new ArrayList<>(); + forInference.add(toTokenize1); + forInference.add(toTokenize2); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); @@ -288,6 +300,9 @@ public class TestBertIterator extends BaseDL4JTest { assertEquals(expM, mds.getFeaturesMaskArray(0)); assertEquals(expL, mds.getLabels(0)); assertEquals(expLM, mds.getLabelsMaskArray(0)); + + assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]); + assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]); } private static class TestSentenceProvider implements LabeledSentenceProvider { diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml index eaf0e6be8..f95b9935b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml @@ -35,15 +35,18 @@ - - - maven-compiler-plugin - - 1.8 - 1.8 - - - + + + + org.apache.maven.plugins + maven-compiler-plugin + + 1.8 + 1.8 + + + + diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml index e647445e9..e3b6444bf 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml @@ -28,16 +28,18 @@ deeplearning4j-parallel-wrapper - - - org.apache.maven.plugins - maven-compiler-plugin - - 1.8 - 1.8 - - - + + + + org.apache.maven.plugins + maven-compiler-plugin + + 1.8 + 1.8 + + + + UTF-8 @@ -84,7 +86,7 @@ org.deeplearning4j - 
deeplearning4j-play_2.11 + deeplearning4j-ui ${project.version} test diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml index 9192bb877..f22f2f6b8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml @@ -79,15 +79,18 @@ - - - maven-compiler-plugin - - 1.8 - 1.8 - - - + + + + org.apache.maven.plugins + maven-compiler-plugin + + 1.8 + 1.8 + + + + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml index 488229fa5..5cf290c31 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml @@ -67,16 +67,9 @@ org.deeplearning4j - deeplearning4j-play_2.11 + deeplearning4j-ui ${deeplearning4j.version} test - - - - net.jpountz.lz4 - lz4 - - diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/.gitignore b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/.gitignore deleted file mode 100644 index 2c1cc352a..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/.gitignore +++ /dev/null @@ -1,17 +0,0 @@ -logs -project/project -project/target -target -tmp -.history -dist -/.idea -/*.iml -/out -/.idea_modules -/.classpath -/.project -/RUNNING_PID -/.settings - - \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/developerReadme.md b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/developerReadme.md deleted file mode 100644 index 47ae877b6..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/developerReadme.md +++ /dev/null @@ -1,88 +0,0 @@ - -# Deeplearning4j UI (deeplearning4j-play) Developer Readme - -The deeplearning4j UI utilizes an embedded Play server for hosting web pages. This guide documents some basic information -about the design, and how new pages can be added. - -## Overview - -As of DL4J 0.7.2, Play version 2.4 is utilized, for Scala 2.10 compatibility (Scala 2.10 support was dropped in Play 2.5 and later). -Due to DL4J being managed by Maven, not all features of Play (such as the conf/routes approach to routing configuration, and -the default internationalization mechanism) are supported. - -Some key concepts: - -- The actual HTML pages are defined in the main/views/org.deeplearning4j.ui.views.* directory - - These views are based on the Twirl template engine: in practice, templates are mostly HTML with a - little Scala code mixed in - - The views themselves need to be rendered to a Scala class (under main/scala). This is currently done manually; - see the 'Building Templates' section. -- Routing is somewhat non-standard in DL4J, compared to a typical Play application - - DL4J has the concept of a 'UI module': a related set of pages to serve via the UI - - Each module implements the UIModule interface, which defines a set of routes, and how to handle page requests - - Custom modules are supported; this is not enabled by default. See PlayUIServer UI_CUSTOM_MODULE_PROPERTY for details. -- Internationalization (i.e., multiple languages) is supported, but not using the standard Play mechanisms -- The main class and entry point is PlayUIServer - -## Building Templates - -The current setup (using Maven + Play) means that the HTML templates need to be built manually.
-Just run ```mvn compile -P buildUiTemplates``` on the command line to do this - this will generate the Scala classes -from all of the HTML files (in the views directory) and copy them to the appropriate location in the scala directory. - -## Adding New Pages - -Adding a new (built-in) page to the UI can be done using the following approach: - -- Define a new module in org.deeplearning4j.ui.module, that implements the UIModule interface - - Provide routes: these are the endpoints that the user will be able to query, and define what methods will be used - to get the result to return for that page. See for example TrainModule for details - - Each route needs to specify the method to call to get the result (must return a play.mvc.Result object) - - Supplier: used to return results with no args; Function: 1 arg; BiFunction and Function3: 2 and 3 args respectively. - For function/bifunction etc, the arguments are specified with colons, like "/myRoute/:myArg"; the called - method should have an appropriate number of arguments to match this. - - Optionally: add code to handle callbacks, storage events, stats storage. See the section below. -- Add the module to the others, in the PlayUIServer. Should be 1 line: ```uiModules.add(new MyNewModule());``` - -## Assets - -Again (due to issues with Maven + Play support), DL4J defines a custom Assets serving system. Assets are static files -such as images, CSS, JavaScript files, etc. -The contents of /resources/deeplearning4jUiAssets/* get mapped to /assets/* on the running server. To add a new asset, -simply add it to the assets folder, and then reference it as normal in the HTML templates or elsewhere. - -## StatsStorage, Callbacks, and Getting Network and Other Info from DL4J - -The Deeplearning4j UI supports callbacks from StatsStorage. StatsStorage instances are objects/interfaces that store -training information - some of it may be static/stored, some of it may be streaming in, in real time. - -When a user calls ```UIServer.getInstance().attach(StatsStorage)``` the provided StatsStorage instance will provide -callbacks to the UI whenever something changes. For example, when new information from a trained network is added to the -StatsStorage from the StatsListener, the modules that are registered for callbacks (of that type) will be notified. -The UI modules can then query the StatsStorage instance (or not) to get the relevant information for displaying in the UI. - -Each UI module specifies the callbacks it wants to receive via Strings (UIModule.getCallbackTypeIDs). Each String is a -key that must match the TypeID of the data to listen for; consequently, there is a correspondence between the TypeID of -the generating class (StatsListener, for example) and the TypeID that the UI module wants to receive. UI modules may -specify zero or more callback type IDs. Any information that is not relevant to the UI module (i.e., doesn't match a -specified TypeID for the module) won't be forwarded on to the UI module. - - -## Internationalization - -The Play framework does provide an internationalization mechanism, though this was found to be problematic when running - Play via a Maven project. Consequently, DL4J defines its own I18N implementation, which is functionally similar to the - Play mechanism. - -The I18N interface and the DefaultI18N class define getMessage(String) methods; for example, getMessage("train.pagetitle") returns the message for that key in the current language. -Currently (as of 0.7.2), the language setting is *server side* not client side - thus only one language can be shown per - UI server.
Furthermore, the implementation is currently set to fall back to English. - -The actual content for internationalization is present under the resources/dl4j_i18n directory. Each file within this - directory should contain an ISO 639 language code (en, ja, de, etc.) as the extension. -Conceptually, the files may be separated; in practice, the contents of all files (for the relevant language) are merged together. -To use an entry, just add a snippet such as ```@i18n.getMessage("train.pagetitle")``` to the HTML template to get the -appropriate entry for the current language. See the train module UI pages for more examples. - -Note also that it is necessary to provide an I18N instance to the templates. See the TrainModule routing section -for an example of how to do this. diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java deleted file mode 100644 index f2dd1e017..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java +++ /dev/null @@ -1,65 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License.
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.api; - -import lombok.AllArgsConstructor; -import lombok.Data; -import play.mvc.Http; -import play.mvc.Result; - -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.function.Supplier; - -/** - * A Route specifies an endpoint that can be queried in the UI - along with how it should be handled - * - * @author Alex Black - */ -@Data -@AllArgsConstructor -public class Route { - private final String route; - private final HttpMethod httpMethod; - private final FunctionType functionType; - private final Supplier supplier; - private final Function function; - private final BiFunction function2; - private final Function request0Function; - private final BiFunction request1Function; - - public Route(String route, HttpMethod method, FunctionType functionType, Supplier supplier) { - this(route, method, functionType, supplier, null, null, null, null); - } - - public Route(String route, HttpMethod method, FunctionType functionType, Function function) { - this(route, method, functionType, null, function, null, null, null); - } - - public static Route request0Function(String route, HttpMethod httpMethod, Function function){ - return new Route(route, httpMethod, FunctionType.Request0Function, null, null, null, function, null); - } - - public static Route request1Function(String route, HttpMethod httpMethod, BiFunction function){ - return new Route(route, httpMethod, FunctionType.Request1Function, null, null, null, null, function); - } - - public Route(String route, HttpMethod method, FunctionType functionType, - BiFunction function) { - this(route, method, functionType, null, null, function, null, null); - } -} diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java deleted file mode 100644 index 0886af8a9..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java +++ /dev/null @@ -1,163 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.module.tsne; - -import org.apache.commons.io.FileUtils; -import org.deeplearning4j.api.storage.StatsStorage; -import org.deeplearning4j.api.storage.StatsStorageEvent; -import org.deeplearning4j.ui.api.FunctionType; -import org.deeplearning4j.ui.api.HttpMethod; -import org.deeplearning4j.ui.api.Route; -import org.deeplearning4j.ui.api.UIModule; -import org.deeplearning4j.ui.i18n.I18NResource; -import play.libs.Files; -import play.libs.Json; -import play.mvc.Http; -import play.mvc.Result; -import play.mvc.Results; - -import java.io.File; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.*; - -import static play.mvc.Results.badRequest; -import static play.mvc.Results.ok; - -/** - * Created by Alex on 25/10/2016. - */ -public class TsneModule implements UIModule { - - private static final String UPLOADED_FILE = "UploadedFile"; - - private Map> knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>()); - - private List uploadedFileLines = null; - - public TsneModule() { - - } - - @Override - public List getCallbackTypeIDs() { - return Collections.emptyList(); - } - - @Override - public List getRoutes() { - Route r1 = new Route("/tsne", HttpMethod.GET, FunctionType.Supplier, - () -> ok(org.deeplearning4j.ui.views.html.tsne.Tsne.apply())); - Route r2 = new Route("/tsne/sessions", HttpMethod.GET, FunctionType.Supplier, this::listSessions); - Route r3 = new Route("/tsne/coords/:sid", HttpMethod.GET, FunctionType.Function, this::getCoords); - Route r4 = Route.request0Function("/tsne/upload", HttpMethod.POST, this::uploadFile); - Route r5 = Route.request1Function("/tsne/post/:sid", HttpMethod.POST, this::postFile); - return Arrays.asList(r1, r2, r3, r4, r5); - } - - @Override - public void reportStorageEvents(Collection events) { - - } - - @Override - public void onAttach(StatsStorage statsStorage) { - - } - - @Override - public void onDetach(StatsStorage statsStorage) { - - } - - @Override - public List getInternationalizationResources() { - return Collections.emptyList(); - } - - private Result listSessions() { - List list = new ArrayList<>(knownSessionIDs.keySet()); - if (uploadedFileLines != null) { - list.add(UPLOADED_FILE); - } - return Results.ok(Json.toJson(list)); - } - - private Result getCoords(String sessionId) { - if (UPLOADED_FILE.equals(sessionId) && uploadedFileLines != null) { - return Results.ok(Json.toJson(uploadedFileLines)); - } else if (knownSessionIDs.containsKey(sessionId)) { - return Results.ok(Json.toJson(knownSessionIDs.get(sessionId))); - } else { - return Results.ok(); - } - } - - private Result uploadFile(Http.Request request) { - Http.MultipartFormData body = request.body().asMultipartFormData(); - List> fileParts = body.getFiles(); - - if (fileParts.isEmpty()) { - return badRequest("No file uploaded"); - } - - Http.MultipartFormData.FilePart uploadedFile = fileParts.get(0); - - String fileName = uploadedFile.getFilename(); - String contentType = uploadedFile.getContentType(); - File file = uploadedFile.getRef().path().toFile(); - - try { - uploadedFileLines = FileUtils.readLines(file, StandardCharsets.UTF_8); - } catch (IOException e) { - return badRequest("Could not read from uploaded file"); - } - - return ok("File uploaded: " + fileName + ", " + contentType + ", " + file); - } - - private Result postFile(Http.Request request, String sid) { - // 
System.out.println("POST FILE CALLED: " + sid); - Http.MultipartFormData body = request.body().asMultipartFormData(); - List> fileParts = body.getFiles(); - - if (fileParts.isEmpty()) { - // System.out.println("**** NO FILE ****"); - return badRequest("No file uploaded"); - } - - Http.MultipartFormData.FilePart uploadedFile = fileParts.get(0); - - String fileName = uploadedFile.getFilename(); - String contentType = uploadedFile.getContentType(); - File file = uploadedFile.getRef().path().toFile(); - - List lines; - try { - // Set to uploadedFileLines as well, as the TSNE UI doesn't allow to properly select Sessions yet - lines = uploadedFileLines = FileUtils.readLines(file); - } catch (IOException e) { - // System.out.println("**** COULD NOT READ FILE ****"); - return badRequest("Could not read from uploaded file"); - } - - knownSessionIDs.put(sid, lines); - - - return ok("File uploaded: " + fileName + ", " + contentType + ", " + file); - } -} diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java deleted file mode 100644 index 76d6e7361..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.play.staticroutes; - -import org.nd4j.shade.guava.net.HttpHeaders; -import lombok.AllArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FilenameUtils; -import org.nd4j.linalg.io.ClassPathResource; -import play.mvc.Result; -import play.mvc.StaticFileMimeTypes; - -import java.io.InputStream; -import java.util.Optional; - -import static play.mvc.Results.ok; - -/** - * Simple function for serving assets. 
Assets are assumed to be in the specified root directory - * - * @author Alex Black - */ -@AllArgsConstructor -@Slf4j -public class Assets { - - public static Result assetRequest(String assetsRootDirectory, String s) { - - String fullPath; - if(s.startsWith("webjars/")){ - fullPath = "META-INF/resources/" + s; - } else { - fullPath = assetsRootDirectory + s; - } - - InputStream inputStream; - try { - inputStream = new ClassPathResource(fullPath).getInputStream(); - } catch (Throwable t) { - log.warn("Could not find requested UI asset: {}", s, t); - return ok(); - } - - String fileName = FilenameUtils.getName(fullPath); - - Optional contentType = StaticFileMimeTypes.fileMimeTypes().forFileName(fileName); - String ct; - ct = contentType.orElse("application/octet-stream"); - - return ok(inputStream) - .withHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\"") - .as(ct); - } -} diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/MultiSessionI18NRoute.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/MultiSessionI18NRoute.java deleted file mode 100644 index ccbef0683..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/MultiSessionI18NRoute.java +++ /dev/null @@ -1,38 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.play.staticroutes; - -import org.deeplearning4j.ui.i18n.I18NProvider; -import play.mvc.Result; - -import java.util.function.BiFunction; - -import static play.mvc.Results.ok; - -/** - * Route for multi-session internationalization setting - * - * @author Tamas Fenyvesi - */ -public class MultiSessionI18NRoute implements BiFunction { - - @Override - public Result apply(String sessionId, String languageCode) { - I18NProvider.getInstance(sessionId).setDefaultLanguage(languageCode); - return ok(); - } -} diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/resources/application.conf b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/resources/application.conf deleted file mode 100644 index 28a4aa208..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/resources/application.conf +++ /dev/null @@ -1,350 +0,0 @@ -# This is the main configuration file for the application. -# https://www.playframework.com/documentation/latest/ConfigFile -# ~~~~~ -# Play uses HOCON as its configuration file format. HOCON has a number -# of advantages over other config formats, but there are two things that -# can be used when modifying settings. 
-# -# You can include other configuration files in this main application.conf file: -#include "extra-config.conf" -# -# You can declare variables and substitute for them: -#mykey = ${some.value} -# -# And if an environment variable exists when there is no other subsitution, then -# HOCON will fall back to substituting environment variable: -#mykey = ${JAVA_HOME} - -## Akka -# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration -# https://www.playframework.com/documentation/latest/JavaAkka#Configuration -# ~~~~~ -# Play uses Akka internally and exposes Akka Streams and actors in Websockets and -# other streaming HTTP responses. -akka { - # "akka.log-config-on-start" is extraordinarly useful because it log the complete - # configuration at INFO level, including defaults and overrides, so it s worth - # putting at the very top. - # - # Put the following in your conf/logback.xml file: - # - # - # - # And then uncomment this line to debug the configuration. - # - #log-config-on-start = true -} - -## Modules -# https://www.playframework.com/documentation/latest/Modules -# ~~~~~ -# Control which modules are loaded when Play starts. Note that modules are -# the replacement for "GlobalSettings", which are deprecated in 2.5.x. -# Please see https://www.playframework.com/documentation/latest/GlobalSettings -# for more information. -# -# You can also extend Play functionality by using one of the publically available -# Play modules: https://playframework.com/documentation/latest/ModuleDirectory -play.modules { - # By default, Play will load any class called Module that is defined - # in the root package (the "app" directory), or you can define them - # explicitly below. - # If there are any built-in modules that you want to disable, you can list them here. - #enabled += my.application.Module - - # If there are any built-in modules that you want to disable, you can list them here. - #disabled += "" -} - -## Internationalisation -# https://www.playframework.com/documentation/latest/JavaI18N -# https://www.playframework.com/documentation/latest/ScalaI18N -# ~~~~~ -# Play comes with its own i18n settings, which allow the user's preferred language -# to map through to internal messages, or allow the language to be stored in a cookie. -play.i18n { - # The application languages - langs = [ "en" ] - - # Whether the language cookie should be secure or not - #langCookieSecure = true - - # Whether the HTTP only attribute of the cookie should be set to true - #langCookieHttpOnly = true -} - -## Play HTTP settings -# ~~~~~ -play.http { - ## Router - # https://www.playframework.com/documentation/latest/JavaRouting - # https://www.playframework.com/documentation/latest/ScalaRouting - # ~~~~~ - # Define the Router object to use for this application. - # This router will be looked up first when the application is starting up, - # so make sure this is the entry point. - # Furthermore, it's assumed your route file is named properly. - # So for an application router like `my.application.Router`, - # you may need to define a router file `conf/my.application.routes`. 
- # Default to Routes in the root package (aka "apps" folder) (and conf/routes) - #router = my.application.Router - - ## Action Creator - # https://www.playframework.com/documentation/latest/JavaActionCreator - # ~~~~~ - #actionCreator = null - - ## ErrorHandler - # https://www.playframework.com/documentation/latest/JavaRouting - # https://www.playframework.com/documentation/latest/ScalaRouting - # ~~~~~ - # If null, will attempt to load a class called ErrorHandler in the root package, - #errorHandler = null - - ## Filters - # https://www.playframework.com/documentation/latest/ScalaHttpFilters - # https://www.playframework.com/documentation/latest/JavaHttpFilters - # ~~~~~ - # Filters run code on every request. They can be used to perform - # common logic for all your actions, e.g. adding common headers. - # Defaults to "Filters" in the root package (aka "apps" folder) - # Alternatively you can explicitly register a class here. - #filters += my.application.Filters - - ## Session & Flash - # https://www.playframework.com/documentation/latest/JavaSessionFlash - # https://www.playframework.com/documentation/latest/ScalaSessionFlash - # ~~~~~ - session { - # Sets the cookie to be sent only over HTTPS. - #secure = true - - # Sets the cookie to be accessed only by the server. - #httpOnly = true - - # Sets the max-age field of the cookie to 5 minutes. - # NOTE: this only sets when the browser will discard the cookie. Play will consider any - # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout, - # you need to put a timestamp in the session and check it at regular intervals to possibly expire it. - #maxAge = 300 - - # Sets the domain on the session cookie. - #domain = "example.com" - } - - flash { - # Sets the cookie to be sent only over HTTPS. - #secure = true - - # Sets the cookie to be accessed only by the server. - #httpOnly = true - } -} - -## Netty Provider -# https://www.playframework.com/documentation/latest/SettingsNetty -# ~~~~~ -play.server.netty { - # Whether the Netty wire should be logged - #log.wire = true - - # If you run Play on Linux, you can use Netty's native socket transport - # for higher performance with less garbage. - #transport = "native" -} - -## WS (HTTP Client) -# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS -# ~~~~~ -# The HTTP client primarily used for REST APIs. The default client can be -# configured directly, but you can also create different client instances -# with customized settings. You must enable this by adding to build.sbt: -# -# libraryDependencies += ws // or javaWs if using java -# -play.ws { - # Sets HTTP requests not to follow 302 requests - #followRedirects = false - - # Sets the maximum number of open HTTP connections for the client. - #ahc.maxConnectionsTotal = 50 - - ## WS SSL - # https://www.playframework.com/documentation/latest/WsSSL - # ~~~~~ - ssl { - # Configuring HTTPS with Play WS does not require programming. You can - # set up both trustManager and keyManager for mutual authentication, and - # turn on JSSE debugging in development with a reload. - #debug.handshake = true - #trustManager = { - # stores = [ - # { type = "JKS", path = "exampletrust.jks" } - # ] - #} - } -} - -## Cache -# https://www.playframework.com/documentation/latest/JavaCache -# https://www.playframework.com/documentation/latest/ScalaCache -# ~~~~~ -# Play comes with an integrated cache API that can reduce the operational -# overhead of repeated requests. 
You must enable this by adding to build.sbt: -# -# libraryDependencies += cache -# -play.cache { - # If you want to bind several caches, you can bind the individually - #bindCaches = ["db-cache", "user-cache", "session-cache"] -} - -## Filters -# https://www.playframework.com/documentation/latest/Filters -# ~~~~~ -# There are a number of built-in filters that can be enabled and configured -# to give Play greater security. You must enable this by adding to build.sbt: -# -# libraryDependencies += filters -# -play.filters { - ## CORS filter configuration - # https://www.playframework.com/documentation/latest/CorsFilter - # ~~~~~ - # CORS is a protocol that allows web applications to make requests from the browser - # across different domains. - # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has - # dependencies on CORS settings. - cors { - # Filter paths by a whitelist of path prefixes - #pathPrefixes = ["/some/path", ...] - - # The allowed origins. If null, all origins are allowed. - #allowedOrigins = ["http://www.example.com"] - - # The allowed HTTP methods. If null, all methods are allowed - #allowedHttpMethods = ["GET", "POST"] - } - - ## CSRF Filter - # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter - # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter - # ~~~~~ - # Play supports multiple methods for verifying that a request is not a CSRF request. - # The primary mechanism is a CSRF token. This token gets placed either in the query string - # or body of every form submitted, and also gets placed in the users session. - # Play then verifies that both tokens are present and match. - csrf { - # Sets the cookie to be sent only over HTTPS - #cookie.secure = true - - # Defaults to CSRFErrorHandler in the root package. - #errorHandler = MyCSRFErrorHandler - } - - ## Security headers filter configuration - # https://www.playframework.com/documentation/latest/SecurityHeaders - # ~~~~~ - # Defines security headers that prevent XSS attacks. - # If enabled, then all options are set to the below configuration by default: - headers { - # The X-Frame-Options header. If null, the header is not set. - #frameOptions = "DENY" - - # The X-XSS-Protection header. If null, the header is not set. - #xssProtection = "1; mode=block" - - # The X-Content-Type-Options header. If null, the header is not set. - #contentTypeOptions = "nosniff" - - # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set. - #permittedCrossDomainPolicies = "master-only" - - # The Content-Security-Policy header. If null, the header is not set. - #contentSecurityPolicy = "default-src 'self'" - } - - ## Allowed hosts filter configuration - # https://www.playframework.com/documentation/latest/AllowedHostsFilter - # ~~~~~ - # Play provides a filter that lets you configure which hosts can access your application. - # This is useful to prevent cache poisoning attacks. - hosts { - # Allow requests to example.com, its subdomains, and localhost:9000. - #allowed = [".example.com", "localhost:9000"] - } -} - -## Evolutions -# https://www.playframework.com/documentation/latest/Evolutions -# ~~~~~ -# Evolutions allows database scripts to be automatically run on startup in dev mode -# for database migrations. 
You must enable this by adding to build.sbt: -# -# libraryDependencies += evolutions -# -play.evolutions { - # You can disable evolutions for a specific datasource if necessary - #db.default.enabled = false -} - -## Database Connection Pool -# https://www.playframework.com/documentation/latest/SettingsJDBC -# ~~~~~ -# Play doesn't require a JDBC database to run, but you can easily enable one. -# -# libraryDependencies += jdbc -# -play.db { - # The combination of these two settings results in "db.default" as the - # default JDBC pool: - #config = "db" - #default = "default" - - # Play uses HikariCP as the default connection pool. You can override - # settings by changing the prototype: - prototype { - # Sets a fixed JDBC connection pool size of 50 - #hikaricp.minimumIdle = 50 - #hikaricp.maximumPoolSize = 50 - } -} - -## JDBC Datasource -# https://www.playframework.com/documentation/latest/JavaDatabase -# https://www.playframework.com/documentation/latest/ScalaDatabase -# ~~~~~ -# Once JDBC datasource is set up, you can work with several different -# database options: -# -# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick -# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA -# EBean: https://playframework.com/documentation/latest/JavaEbean -# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm -# -db { - # You can declare as many datasources as you want. - # By convention, the default datasource is named `default` - - # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database - default.driver = org.h2.Driver - default.url = "jdbc:h2:mem:play" - #default.username = sa - #default.password = "" - - # You can expose this datasource via JNDI if needed (Useful for JPA) - default.jndiName=DefaultDS - - # You can turn on SQL logging for any datasource - # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements - #default.logSql=true -} - -jpa.default=defaultPersistenceUnit - - -#Increase default maximum post length - used for remote listener functionality -#Can get response 413 with larger networks without setting this -# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead -#parsers.text.maxLength=10M -play.http.parser.maxMemoryBuffer=10M diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/convolutional/Activations.template.scala b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/convolutional/Activations.template.scala deleted file mode 100644 index 20672a6bd..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/convolutional/Activations.template.scala +++ /dev/null @@ -1,159 +0,0 @@ - -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.views.html.convolutional - -import play.twirl.api._ -import play.twirl.api.TemplateMagic._ - - - object Activations_Scope0 { -import models._ -import controllers._ -import play.api.i18n._ -import views.html._ -import play.api.templates.PlayMagic._ -import play.api.mvc._ -import play.api.data._ - -class Activations extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template0[play.twirl.api.HtmlFormat.Appendable] { - - /**/ - def apply():play.twirl.api.HtmlFormat.Appendable = { - _display_ { - { - - -Seq[Any](format.raw/*1.1*/(""" - - - - - - - Neural Network activations - - - """),format.raw/*25.68*/(""" - """),format.raw/*26.13*/(""" - - - - - - - - - - - - - - - - - - - - - - -
DeepLearning4j UI 
-

-
-
- -
-
- - -""")) - } - } - } - - def render(): play.twirl.api.HtmlFormat.Appendable = apply() - - def f:(() => play.twirl.api.HtmlFormat.Appendable) = () => apply() - - def ref: this.type = this - -} - - -} - -/**/ -object Activations extends Activations_Scope0.Activations - /* - -- GENERATED -- - DATE: Tue May 07 21:39:40 AEST 2019 - SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/convolutional/Activations.scala.html - HASH: fb7b9cdfdd31f76d21208213b839aeebac3dabe2 - MATRIX: 657->0|1724->1098|1766->1112|2133->1451|2162->1452|2204->1466|2362->1597|2391->1598|2430->1610|2462->1614|2491->1615|2533->1629|2655->1724|2684->1725|2723->1737|2758->1744|2787->1745|2829->1759|3016->1919|3045->1920|3084->1932|3122->1942|3151->1943|3193->1957|3315->2052|3344->2053|3381->2063|3494->2148|3523->2149|3565->2163|3730->2301|3759->2302 - LINES: 25->1|49->25|50->26|57->33|57->33|58->34|62->38|62->38|64->40|64->40|64->40|65->41|68->44|68->44|70->46|70->46|70->46|71->47|76->52|76->52|78->54|78->54|78->54|79->55|82->58|82->58|83->59|86->62|86->62|87->63|89->65|89->65 - -- GENERATED -- - */ - \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/samediff/SameDiffUI.template.scala b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/samediff/SameDiffUI.template.scala deleted file mode 100644 index 80191c81d..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/samediff/SameDiffUI.template.scala +++ /dev/null @@ -1,240 +0,0 @@ - -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.views.html.samediff - -import play.twirl.api._ -import play.twirl.api.TemplateMagic._ - - - object SameDiffUI_Scope0 { -import models._ -import controllers._ -import play.api.i18n._ -import views.html._ -import play.api.templates.PlayMagic._ -import play.api.mvc._ -import play.api.data._ - -class SameDiffUI extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template0[play.twirl.api.HtmlFormat.Appendable] { - - /**/ - def apply/*1.2*/():play.twirl.api.HtmlFormat.Appendable = { - _display_ { - { - - -Seq[Any](format.raw/*1.4*/(""" -"""),format.raw/*2.1*/(""" - - - - - - - - SameDiff Graph Visualization - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - """),format.raw/*71.34*/(""" - """),format.raw/*72.9*/("""
- - -
- - """),format.raw/*98.47*/(""" - """),format.raw/*99.17*/("""
- """),format.raw/*100.81*/(""" - """),format.raw/*101.21*/("""
-
- -
-
[No File Loaded]
-
- -
- - """),format.raw/*114.85*/(""" - """),format.raw/*115.21*/("""
-
Selected Node:
-
(None)
-
- -
- -
-
- Find Node:
-
- -
- -
-
- -
- - """),format.raw/*134.87*/(""" - """),format.raw/*135.21*/("""
-

- Graph Layout: -
- - - - -
-
-
-
-
-
- - -
-
-
-
-
- - - - - - -""")) - } - } - } - - def render(): play.twirl.api.HtmlFormat.Appendable = apply() - - def f:(() => play.twirl.api.HtmlFormat.Appendable) = () => apply() - - def ref: this.type = this - -} - - -} - -/**/ -object SameDiffUI extends SameDiffUI_Scope0.SameDiffUI - /* - -- GENERATED -- - DATE: Tue May 07 21:39:40 AEST 2019 - SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/samediff/SameDiffUI.scala.html - HASH: d80f43e3f37ebbfb22768edb52560e2d6d6fcb2a - MATRIX: 561->1|657->3|685->5|4621->3938|4658->3948|6437->5729|6483->5747|6664->5959|6715->5981|7384->6685|7435->6707|8217->7526|8268->7548|10114->9365|10144->9366|10195->9388|10264->9428|10294->9429 - LINES: 20->1|25->1|26->2|95->71|96->72|122->98|123->99|124->100|125->101|138->114|139->115|158->134|159->135|191->167|191->167|192->168|193->169|193->169 - -- GENERATED -- - */ - \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/samediff/SameDiffUIBackup.template.scala b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/samediff/SameDiffUIBackup.template.scala deleted file mode 100644 index 426384ec4..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/samediff/SameDiffUIBackup.template.scala +++ /dev/null @@ -1,175 +0,0 @@ - -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.views.html.samediff - -import play.twirl.api._ -import play.twirl.api.TemplateMagic._ - - - object SameDiffUIBackup_Scope0 { -import models._ -import controllers._ -import play.api.i18n._ -import views.html._ -import play.api.templates.PlayMagic._ -import play.api.mvc._ -import play.api.data._ - -class SameDiffUIBackup extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template0[play.twirl.api.HtmlFormat.Appendable] { - - /**/ - def apply/*1.2*/():play.twirl.api.HtmlFormat.Appendable = { - _display_ { - { - - -Seq[Any](format.raw/*1.4*/(""" -"""),format.raw/*2.1*/(""" - - - - - - - - SameDiff Graph Visualization - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-
- - - """),format.raw/*77.78*/(""" - """),format.raw/*78.17*/("""
- Layout: - - - - """),format.raw/*83.76*/(""" - """),format.raw/*84.78*/(""" - """),format.raw/*85.74*/(""" - """),format.raw/*86.21*/(""" - - """),format.raw/*88.74*/(""" - """),format.raw/*89.17*/("""
-
-
-
- -
- - - - - - -""")) - } - } - } - - def render(): play.twirl.api.HtmlFormat.Appendable = apply() - - def f:(() => play.twirl.api.HtmlFormat.Appendable) = () => apply() - - def ref: this.type = this - -} - - -} - -/**/ -object SameDiffUIBackup extends SameDiffUIBackup_Scope0.SameDiffUIBackup - /* - -- GENERATED -- - DATE: Sat Jan 26 18:11:00 AEDT 2019 - SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/samediff/SameDiffUIBackup.scala.html - HASH: 5bbca606352004bcd370a7bff06c376141f0d082 - MATRIX: 573->1|669->3|697->5|4607->3948|4653->3966|4993->4333|5043->4412|5093->4487|5143->4509|5346->4737|5392->4755|5813->5147|5843->5148|5894->5170|5963->5210|5993->5211 - LINES: 20->1|25->1|26->2|101->77|102->78|107->83|108->84|109->85|110->86|112->88|113->89|126->102|126->102|127->103|128->104|128->104 - -- GENERATED -- - */ - \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/training/TrainingModel.template.scala b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/training/TrainingModel.template.scala deleted file mode 100644 index 16b37a3a9..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/training/TrainingModel.template.scala +++ /dev/null @@ -1,334 +0,0 @@ - -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.views.html.training - -import play.twirl.api._ -import play.twirl.api.TemplateMagic._ - - - object TrainingModel_Scope0 { -import models._ -import controllers._ -import play.api.i18n._ -import views.html._ -import play.api.templates.PlayMagic._ -import play.api.mvc._ -import play.api.data._ - -class TrainingModel extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template1[org.deeplearning4j.ui.api.I18N,play.twirl.api.HtmlFormat.Appendable] { - - /**/ - def apply/*1.2*/(i18n: org.deeplearning4j.ui.api.I18N):play.twirl.api.HtmlFormat.Appendable = { - _display_ { - { - - -Seq[Any](format.raw/*1.40*/(""" -"""),format.raw/*2.1*/(""" - - - - - - - - """),_display_(/*24.17*/i18n/*24.21*/.getMessage("train.pagetitle")),format.raw/*24.51*/(""" - - - - - - - - - - - - - - - - - - - - - - - -
- - - - - -
-
-
-
-
- - -
- -
-
-

"""),_display_(/*109.49*/i18n/*109.53*/.getMessage("train.model.layerInfoTable.title")),format.raw/*109.100*/("""

-
-
-
-
-
- -
- -
-
-

log - 10 """),_display_(/*128.51*/i18n/*128.55*/.getMessage("train.overview.chart.updateRatioTitleShort")),format.raw/*128.112*/(""" 0, """),_display_(/*128.165*/i18n/*128.169*/.getMessage("train.overview.charts.iteration")),format.raw/*128.215*/(""": 0

-
-
-
-
-

"""),_display_(/*133.49*/i18n/*133.53*/.getMessage("train.model.activationsChart.title")),format.raw/*133.102*/("""

-
-
-
-

"""),_display_(/*137.63*/i18n/*137.67*/.getMessage("train.model.activationsChart.titleShort")),format.raw/*137.121*/(""" - """),format.raw/*138.33*/(""": 0 - , """),_display_(/*139.47*/i18n/*139.51*/.getMessage("train.overview.charts.iteration")),format.raw/*139.97*/(""" - """),format.raw/*140.33*/(""": 0

-
-
- -
-
-

"""),_display_(/*146.49*/i18n/*146.53*/.getMessage("train.model.paramHistChart.title")),format.raw/*146.100*/("""

-
-
-
-
-
-
-
- -
-
-

"""),_display_(/*157.49*/i18n/*157.53*/.getMessage("train.model.updateHistChart.title")),format.raw/*157.101*/("""

-
-
-
-
-
-
-
- -
-
-

"""),_display_(/*168.49*/i18n/*168.53*/.getMessage("train.model.lrChart.title")),format.raw/*168.93*/("""

-
-
-
-

"""),_display_(/*172.63*/i18n/*172.67*/.getMessage("train.model.lrChart.titleShort")),format.raw/*172.112*/(""" - """),format.raw/*173.33*/(""": 0 - , """),_display_(/*174.47*/i18n/*174.51*/.getMessage("train.overview.charts.iteration")),format.raw/*174.97*/(""" - """),format.raw/*175.33*/(""": 0

-
-
- -
- - - -
-
-
-

Getting Started

-
-
- -
-
-

Overview

-

- The layer visualization UI renders network structure dynamically. Users can inspect node layer parameters by clicking on the various elements of the GUI to see general information as well as overall network information such as performance. -

-

Actions

-

On the left, you will find an interactive layer visualization.

-

-

    -
  • Clicking - Click on a layer to load network performance metrics.
  • -
  • Scrolling - - Drag the GUI with your mouse or touchpad to move the model around.
  • -
-

-
-
-
-
-
- -
- - - - - -
- -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -""")) - } - } - } - - def render(i18n:org.deeplearning4j.ui.api.I18N): play.twirl.api.HtmlFormat.Appendable = apply(i18n) - - def f:((org.deeplearning4j.ui.api.I18N) => play.twirl.api.HtmlFormat.Appendable) = (i18n) => apply(i18n) - - def ref: this.type = this - -} - - -} - -/**/ -object TrainingModel extends TrainingModel_Scope0.TrainingModel - /* - -- GENERATED -- - DATE: Tue May 07 21:39:41 AEST 2019 - SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingModel.scala.html - HASH: e97fc151ff2478d42c14c4d98e2cf87aa2ee049f - MATRIX: 598->1|731->39|759->41|1677->932|1690->936|1741->966|2753->1951|2766->1955|2817->1985|2988->2129|3001->2133|3056->2167|3409->2493|3422->2497|3484->2538|4020->3046|4034->3050|4089->3083|4241->3207|4255->3211|4307->3241|4466->3372|4480->3376|4533->3407|4778->3625|4791->3629|4846->3662|4904->3692|6344->5104|6373->5105|6423->5127|6607->5283|6636->5284|6682->5302|7246->5838|7260->5842|7330->5889|7824->6355|7838->6359|7913->6411|8280->6749|8295->6753|8362->6797|8536->6942|8551->6946|8618->6990|8794->7137|8809->7141|8877->7186|9326->7607|9340->7611|9420->7668|9502->7721|9517->7725|9586->7771|9900->8057|9914->8061|9986->8110|10295->8391|10309->8395|10386->8449|10449->8483|10563->8569|10577->8573|10645->8619|10708->8653|11029->8946|11043->8950|11113->8997|11843->9699|11857->9703|11928->9751|12657->10452|12671->10456|12733->10496|13043->10778|13057->10782|13125->10827|13188->10861|13303->10948|13317->10952|13385->10998|13448->11032|18024->15579|18054->15580|18104->15601|18212->15680|18242->15681|18413->15823|18443->15824|18494->15846|18564->15887|18594->15888 - LINES: 20->1|25->1|26->2|48->24|48->24|48->24|70->46|70->46|70->46|72->48|72->48|72->48|78->54|78->54|78->54|90->66|90->66|90->66|91->67|91->67|91->67|92->68|92->68|92->68|95->71|95->71|95->71|96->72|113->89|113->89|114->90|118->94|118->94|119->95|133->109|133->109|133->109|142->118|142->118|142->118|144->120|144->120|144->120|145->121|145->121|145->121|146->122|146->122|146->122|152->128|152->128|152->128|152->128|152->128|152->128|157->133|157->133|157->133|161->137|161->137|161->137|162->138|163->139|163->139|163->139|164->140|170->146|170->146|170->146|181->157|181->157|181->157|192->168|192->168|192->168|196->172|196->172|196->172|197->173|198->174|198->174|198->174|199->175|277->253|277->253|278->254|280->256|280->256|285->261|285->261|286->262|287->263|287->263 - -- GENERATED -- - */ - \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/training/TrainingOverview.template.scala b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/training/TrainingOverview.template.scala deleted file mode 100644 index 64469ca6a..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/training/TrainingOverview.template.scala +++ /dev/null @@ -1,288 +0,0 @@ - -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.views.html.training - -import play.twirl.api._ -import play.twirl.api.TemplateMagic._ - - - object TrainingOverview_Scope0 { -import models._ -import controllers._ -import play.api.i18n._ -import views.html._ -import play.api.templates.PlayMagic._ -import play.api.mvc._ -import play.api.data._ - -class TrainingOverview extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template1[org.deeplearning4j.ui.api.I18N,play.twirl.api.HtmlFormat.Appendable] { - - /**/ - def apply/*1.2*/(i18n: org.deeplearning4j.ui.api.I18N):play.twirl.api.HtmlFormat.Appendable = { - _display_ { - { - - -Seq[Any](format.raw/*1.40*/(""" -"""),format.raw/*2.1*/(""" - - - - - - - - """),_display_(/*24.17*/i18n/*24.21*/.getMessage("train.pagetitle")),format.raw/*24.51*/(""" - - - - - - - - - - - - - - - - - - - - - -
- -
-
- -
-
-

"""),_display_(/*89.37*/i18n/*89.41*/.getMessage("train.overview.chart.scoreTitle")),format.raw/*89.87*/("""

-
-
-
-

"""),_display_(/*93.51*/i18n/*93.55*/.getMessage("train.overview.chart.scoreTitleShort")),format.raw/*93.106*/(""" - """),format.raw/*94.33*/(""": 0, """),_display_(/*94.66*/i18n/*94.70*/.getMessage("train.overview.charts.iteration")),format.raw/*94.116*/(""" - """),format.raw/*95.33*/(""": - 0

-
-
- - -
-
-

"""),_display_(/*103.37*/i18n/*103.41*/.getMessage("train.overview.perftable.title")),format.raw/*103.86*/("""

-
-
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
"""),_display_(/*108.42*/i18n/*108.46*/.getMessage("train.overview.modeltable.modeltype")),format.raw/*108.96*/("""Loading...
"""),_display_(/*112.42*/i18n/*112.46*/.getMessage("train.overview.modeltable.nLayers")),format.raw/*112.94*/("""Loading...
"""),_display_(/*116.42*/i18n/*116.46*/.getMessage("train.overview.modeltable.nParams")),format.raw/*116.94*/("""Loading...
"""),_display_(/*120.42*/i18n/*120.46*/.getMessage("train.overview.perftable.startTime")),format.raw/*120.95*/("""Loading...
"""),_display_(/*124.42*/i18n/*124.46*/.getMessage("train.overview.perftable.totalRuntime")),format.raw/*124.98*/("""Loading...
"""),_display_(/*128.42*/i18n/*128.46*/.getMessage("train.overview.perftable.lastUpdate")),format.raw/*128.96*/("""Loading...
"""),_display_(/*132.42*/i18n/*132.46*/.getMessage("train.overview.perftable.totalParamUpdates")),format.raw/*132.103*/("""Loading...
"""),_display_(/*136.42*/i18n/*136.46*/.getMessage("train.overview.perftable.updatesPerSec")),format.raw/*136.99*/("""Loading...
"""),_display_(/*140.42*/i18n/*140.46*/.getMessage("train.overview.perftable.examplesPerSec")),format.raw/*140.100*/("""Loading...
-
-
- -
- - -
- -
-
-

"""),_display_(/*154.37*/i18n/*154.41*/.getMessage("train.overview.chart.updateRatioTitle")),format.raw/*154.93*/(""": log10

-
-
-
-

"""),_display_(/*158.51*/i18n/*158.55*/.getMessage("train.overview.chart.updateRatioTitleShort")),format.raw/*158.112*/(""" - """),format.raw/*159.33*/(""": 0, log - 10 """),_display_(/*160.43*/i18n/*160.47*/.getMessage("train.overview.chart.updateRatioTitleShort")),format.raw/*160.104*/(""" - """),format.raw/*161.33*/(""": 0 - , """),_display_(/*162.39*/i18n/*162.43*/.getMessage("train.overview.charts.iteration")),format.raw/*162.89*/(""": - 0

-
- -
- - -
- -
-
-

"""),_display_(/*180.51*/i18n/*180.55*/.getMessage("train.overview.chart.stdevTitleShort")),format.raw/*180.106*/(""" - """),format.raw/*181.33*/(""": 0, log - 10 """),_display_(/*182.43*/i18n/*182.47*/.getMessage("train.overview.chart.stdevTitleShort")),format.raw/*182.98*/(""" - """),format.raw/*183.33*/(""": 0 - , """),_display_(/*184.39*/i18n/*184.43*/.getMessage("train.overview.charts.iteration")),format.raw/*184.89*/(""": - 0

-
-
- -
-
-
- - - - - - - - - - - - - - - - - - - - - -""")) - } - } - } - - def render(i18n:org.deeplearning4j.ui.api.I18N): play.twirl.api.HtmlFormat.Appendable = apply(i18n) - - def f:((org.deeplearning4j.ui.api.I18N) => play.twirl.api.HtmlFormat.Appendable) = (i18n) => apply(i18n) - - def ref: this.type = this - -} - - -} - -/**/ -object TrainingOverview extends TrainingOverview_Scope0.TrainingOverview - /* - -- GENERATED -- - DATE: Tue May 07 21:39:41 AEST 2019 - SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingOverview.scala.html - HASH: 2c9149a6ec8ab4fcc1b316e6ecd5befab2c45af1 - MATRIX: 604->1|737->39|765->41|1683->932|1696->936|1747->966|2759->1951|2772->1955|2823->1985|2994->2129|3007->2133|3062->2167|3415->2493|3428->2497|3490->2538|3995->3015|4009->3019|4064->3052|4216->3176|4230->3180|4282->3210|4441->3341|4455->3345|4508->3376|4753->3594|4766->3598|4821->3631|4879->3661|6460->5215|6473->5219|6540->5265|6791->5489|6804->5493|6877->5544|6939->5578|6999->5611|7012->5615|7080->5661|7142->5695|7536->6061|7550->6065|7617->6110|7894->6359|7908->6363|7980->6413|8205->6610|8219->6614|8289->6662|8512->6857|8526->6861|8596->6909|8819->7104|8833->7108|8904->7157|9129->7354|9143->7358|9217->7410|9445->7610|9459->7614|9531->7664|9757->7862|9771->7866|9851->7923|10084->8128|10098->8132|10173->8185|10402->8386|10416->8390|10493->8444|11042->8965|11056->8969|11130->9021|11422->9285|11436->9289|11516->9346|11579->9380|11696->9469|11710->9473|11790->9530|11853->9564|11956->9639|11970->9643|12038->9689|12444->10067|12458->10071|12526->10117|12858->10420|12873->10424|12952->10480|13124->10623|13139->10627|13216->10681|13384->10820|13399->10824|13474->10876|13777->11151|13791->11155|13865->11206|13928->11240|14045->11329|14059->11333|14132->11384|14195->11418|14298->11493|14312->11497|14380->11543|15679->12813|15709->12814|15760->12836|15832->12879|15862->12880|16027->13016|16057->13017|16108->13039|16181->13083|16211->13084 - LINES: 20->1|25->1|26->2|48->24|48->24|48->24|70->46|70->46|70->46|72->48|72->48|72->48|78->54|78->54|78->54|88->64|88->64|88->64|89->65|89->65|89->65|90->66|90->66|90->66|93->69|93->69|93->69|94->70|113->89|113->89|113->89|117->93|117->93|117->93|118->94|118->94|118->94|118->94|119->95|127->103|127->103|127->103|132->108|132->108|132->108|136->112|136->112|136->112|140->116|140->116|140->116|144->120|144->120|144->120|148->124|148->124|148->124|152->128|152->128|152->128|156->132|156->132|156->132|160->136|160->136|160->136|164->140|164->140|164->140|178->154|178->154|178->154|182->158|182->158|182->158|183->159|184->160|184->160|184->160|185->161|186->162|186->162|186->162|195->171|195->171|195->171|197->173|197->173|197->173|198->174|198->174|198->174|199->175|199->175|199->175|204->180|204->180|204->180|205->181|206->182|206->182|206->182|207->183|208->184|208->184|208->184|232->208|232->208|233->209|234->210|234->210|239->215|239->215|240->216|241->217|241->217 - -- GENERATED -- - */ - \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/training/TrainingSystem.template.scala b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/training/TrainingSystem.template.scala deleted file mode 100644 index 9ca429f48..000000000 --- 
a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/training/TrainingSystem.template.scala +++ /dev/null @@ -1,349 +0,0 @@ - -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.views.html.training - -import play.twirl.api._ -import play.twirl.api.TemplateMagic._ - - - object TrainingSystem_Scope0 { -import models._ -import controllers._ -import play.api.i18n._ -import views.html._ -import play.api.templates.PlayMagic._ -import play.api.mvc._ -import play.api.data._ - -class TrainingSystem extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template1[org.deeplearning4j.ui.api.I18N,play.twirl.api.HtmlFormat.Appendable] { - - /**/ - def apply/*1.2*/(i18n: org.deeplearning4j.ui.api.I18N):play.twirl.api.HtmlFormat.Appendable = { - _display_ { - { - - -Seq[Any](format.raw/*1.40*/(""" -"""),format.raw/*2.1*/(""" - - - - - - - - """),_display_(/*24.17*/i18n/*24.21*/.getMessage("train.pagetitle")),format.raw/*24.51*/(""" - - - - - - - - - - - - - - - - - - - - - -
- - - -
- -
- -
-
-

"""),_display_(/*92.41*/i18n/*92.45*/.getMessage("train.system.title")),format.raw/*92.78*/("""

-
- - -
-
-
- - -
-
- - -
- -
-
-

"""),_display_(/*109.61*/i18n/*109.65*/.getMessage("train.system.chart.systemMemoryTitle")),format.raw/*109.116*/(""" """),format.raw/*109.117*/("""%

-
-
-
-

"""),_display_(/*113.75*/i18n/*113.79*/.getMessage("train.system.chart.memoryShort")),format.raw/*113.124*/(""": 0, - """),_display_(/*114.58*/i18n/*114.62*/.getMessage("train.overview.charts.iteration")),format.raw/*114.108*/(""": 0

-
-
- - -
-
-

"""),_display_(/*121.61*/i18n/*121.65*/.getMessage("train.system.chart.gpuMemoryTitle")),format.raw/*121.113*/(""" """),format.raw/*121.114*/("""%

-
-
-
-

"""),_display_(/*125.75*/i18n/*125.79*/.getMessage("train.system.chart.memoryShort")),format.raw/*125.124*/(""": 0, - """),_display_(/*126.58*/i18n/*126.62*/.getMessage("train.overview.charts.iteration")),format.raw/*126.108*/(""": 0

-
-
- -
- - -
- - -
-
-

"""),_display_(/*138.61*/i18n/*138.65*/.getMessage("train.system.hwTable.title")),format.raw/*138.106*/("""

-
-
- - - - - - - - - - - - - - - - - - - - - -
"""),_display_(/*144.70*/i18n/*144.74*/.getMessage("train.system.hwTable.jvmCurrent")),format.raw/*144.120*/(""""""),_display_(/*145.70*/i18n/*145.74*/.getMessage("train.system.hwTable.jvmMax")),format.raw/*145.116*/(""""""),_display_(/*146.70*/i18n/*146.74*/.getMessage("train.system.hwTable.offHeapCurrent")),format.raw/*146.124*/(""""""),_display_(/*147.70*/i18n/*147.74*/.getMessage("train.system.hwTable.offHeapMax")),format.raw/*147.120*/(""""""),_display_(/*148.70*/i18n/*148.74*/.getMessage("train.system.hwTable.jvmProcs")),format.raw/*148.118*/(""""""),_display_(/*149.70*/i18n/*149.74*/.getMessage("train.system.hwTable.computeDevices")),format.raw/*149.124*/("""
Loading...Loading...Loading...Loading...Loading...Loading...
-
-
- -
- -
- - -
-
-

"""),_display_(/*173.61*/i18n/*173.65*/.getMessage("train.system.swTable.title")),format.raw/*173.106*/("""

-
-
- - - - - - - - - - - - - - - - - - - - - - - -
"""),_display_(/*179.70*/i18n/*179.74*/.getMessage("train.system.swTable.hostname")),format.raw/*179.118*/(""""""),_display_(/*180.70*/i18n/*180.74*/.getMessage("train.system.swTable.os")),format.raw/*180.112*/(""""""),_display_(/*181.70*/i18n/*181.74*/.getMessage("train.system.swTable.osArch")),format.raw/*181.116*/(""""""),_display_(/*182.70*/i18n/*182.74*/.getMessage("train.system.swTable.jvmName")),format.raw/*182.117*/(""""""),_display_(/*183.70*/i18n/*183.74*/.getMessage("train.system.swTable.jvmVersion")),format.raw/*183.120*/(""""""),_display_(/*184.70*/i18n/*184.74*/.getMessage("train.system.swTable.nd4jBackend")),format.raw/*184.121*/(""""""),_display_(/*185.70*/i18n/*185.74*/.getMessage("train.system.swTable.nd4jDataType")),format.raw/*185.122*/("""
Loading...Loading...Loading...Loading...Loading...Loading...Loading...
-
-
- -
- - """),format.raw/*205.73*/(""" - """),format.raw/*206.82*/(""" - """),format.raw/*207.73*/(""" - """),format.raw/*208.79*/(""" - """),format.raw/*209.88*/(""" - """),format.raw/*210.59*/(""" - """),format.raw/*211.78*/(""" - """),format.raw/*212.92*/(""" - """),format.raw/*213.68*/(""" - """),format.raw/*214.69*/(""" - """),format.raw/*215.79*/(""" - """),format.raw/*216.70*/(""" - """),format.raw/*217.69*/(""" - """),format.raw/*218.68*/(""" - """),format.raw/*219.69*/(""" - """),format.raw/*220.108*/(""" - """),format.raw/*221.70*/(""" - """),format.raw/*222.69*/(""" - """),format.raw/*223.65*/(""" - """),format.raw/*224.59*/(""" - """),format.raw/*225.55*/(""" - """),format.raw/*226.51*/(""" - """),format.raw/*227.37*/("""
- -
-
-
- -
-
-
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -""")) - } - } - } - - def render(i18n:org.deeplearning4j.ui.api.I18N): play.twirl.api.HtmlFormat.Appendable = apply(i18n) - - def f:((org.deeplearning4j.ui.api.I18N) => play.twirl.api.HtmlFormat.Appendable) = (i18n) => apply(i18n) - - def ref: this.type = this - -} - - -} - -/**/ -object TrainingSystem extends TrainingSystem_Scope0.TrainingSystem - /* - -- GENERATED -- - DATE: Tue May 07 21:39:41 AEST 2019 - SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingSystem.scala.html - HASH: a5fe2fb79cf26f96000f494801fcbb6b9f354dd5 - MATRIX: 600->1|733->39|761->41|1679->932|1692->936|1743->966|2755->1951|2768->1955|2819->1985|2990->2129|3003->2133|3058->2167|3411->2493|3424->2497|3486->2538|3991->3015|4005->3019|4060->3052|4212->3176|4226->3180|4278->3210|4437->3341|4451->3345|4504->3376|4749->3594|4762->3598|4817->3631|4875->3661|6525->5284|6538->5288|6592->5321|6851->5552|6865->5556|6928->5597|6958->5598|7821->6433|7835->6437|7909->6488|7940->6489|8316->6837|8330->6841|8398->6886|8517->6977|8531->6981|8600->7027|9078->7477|9092->7481|9163->7529|9194->7530|9567->7875|9581->7879|9649->7924|9769->8016|9783->8020|9852->8066|10487->8673|10501->8677|10565->8718|11024->9149|11038->9153|11107->9199|11211->9275|11225->9279|11290->9321|11394->9397|11408->9401|11481->9451|11585->9527|11599->9531|11668->9577|11772->9653|11786->9657|11853->9701|11957->9777|11971->9781|12044->9831|13693->11452|13707->11456|13771->11497|14230->11928|14244->11932|14311->11976|14415->12052|14429->12056|14490->12094|14594->12170|14608->12174|14673->12216|14777->12292|14791->12296|14857->12339|14961->12415|14975->12419|15044->12465|15148->12541|15162->12545|15232->12592|15336->12668|15350->12672|15421->12720|16821->14119|16892->14202|16967->14276|17046->14356|17129->14445|17208->14505|17287->14584|17370->14677|17457->14746|17548->14816|17643->14896|17734->14967|17821->15037|17908->15106|17999->15176|18095->15285|18186->15356|18273->15426|18356->15492|18435->15552|18510->15608|18581->15660|18648->15698|20855->17876|20885->17877|20936->17899|21222->18156|21252->18157|21420->18296|21450->18297|21501->18319|21572->18361|21602->18362 - LINES: 20->1|25->1|26->2|48->24|48->24|48->24|70->46|70->46|70->46|72->48|72->48|72->48|78->54|78->54|78->54|88->64|88->64|88->64|89->65|89->65|89->65|90->66|90->66|90->66|93->69|93->69|93->69|94->70|116->92|116->92|116->92|118->94|118->94|118->94|118->94|133->109|133->109|133->109|133->109|137->113|137->113|137->113|138->114|138->114|138->114|145->121|145->121|145->121|145->121|149->125|149->125|149->125|150->126|150->126|150->126|162->138|162->138|162->138|168->144|168->144|168->144|169->145|169->145|169->145|170->146|170->146|170->146|171->147|171->147|171->147|172->148|172->148|172->148|173->149|173->149|173->149|197->173|197->173|197->173|203->179|203->179|203->179|204->180|204->180|204->180|205->181|205->181|205->181|206->182|206->182|206->182|207->183|207->183|207->183|208->184|208->184|208->184|209->185|209->185|209->185|229->205|230->206|231->207|232->208|233->209|234->210|235->211|236->212|237->213|238->214|239->215|240->216|241->217|242->218|243->219|244->220|245->221|246->222|247->223|248->224|249->225|250->226|251->227|286->262|286->262|287->263|293->269|293->269|298->274|298->274|299->275|300->276|300->276 - -- GENERATED -- - */ - \ No newline at end of file diff --git 
a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/tsne/Tsne.template.scala b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/tsne/Tsne.template.scala deleted file mode 100644 index e9a1225e8..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/scala/org/deeplearning4j/ui/views/html/tsne/Tsne.template.scala +++ /dev/null @@ -1,296 +0,0 @@ - -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.views.html.tsne - -import play.twirl.api._ -import play.twirl.api.TemplateMagic._ - - - object Tsne_Scope0 { -import models._ -import controllers._ -import play.api.i18n._ -import views.html._ -import play.api.templates.PlayMagic._ -import play.api.mvc._ -import play.api.data._ - -class Tsne extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template0[play.twirl.api.HtmlFormat.Appendable] { - - /**/ - def apply():play.twirl.api.HtmlFormat.Appendable = { - _display_ { - { - - -Seq[Any](format.raw/*1.1*/(""" - - - - - - - T-SNE renders - - - - - - - - """),format.raw/*30.69*/(""" - """),format.raw/*31.9*/(""" - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
DeepLearning4j UI  Available sessions: - -  
- -
-
-
-
-
-
-
-
-
Upload a file to UI server.
-
-
- - - - -
-
-
- - """),format.raw/*219.53*/(""" - """),format.raw/*220.96*/(""" - """),format.raw/*221.42*/(""" - """),format.raw/*222.59*/(""" - """),format.raw/*223.68*/(""" - """),format.raw/*224.27*/(""" - """),format.raw/*225.23*/(""" - """),format.raw/*226.9*/("""
- - -""")) - } - } - } - - def render(): play.twirl.api.HtmlFormat.Appendable = apply() - - def f:(() => play.twirl.api.HtmlFormat.Appendable) = () => apply() - - def ref: this.type = this - -} - - -} - -/**/ -object Tsne extends Tsne_Scope0.Tsne - /* - -- GENERATED -- - DATE: Tue May 07 21:39:41 AEST 2019 - SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/tsne/Tsne.scala.html - HASH: 9c47ea63f8745ed1b8d2d67657f185532c1a0c78 - MATRIX: 634->0|1937->1335|1974->1345|2608->1951|2637->1952|2679->1966|2801->2061|2830->2062|2869->2074|2904->2081|2933->2082|2975->2096|3162->2256|3191->2257|3230->2269|3268->2279|3297->2280|3339->2294|3461->2389|3490->2390|3529->2402|3562->2407|3591->2408|3633->2422|3791->2553|3820->2554|3859->2566|3893->2572|3922->2573|3964->2587|4079->2675|4108->2676|4147->2688|4182->2695|4211->2696|4253->2710|4307->2737|4336->2738|4375->2750|4406->2753|4435->2754|4477->2768|4567->2831|4596->2832|4635->2844|4667->2848|4696->2849|4738->2863|4900->2998|4929->2999|4968->3011|5002->3017|5031->3018|5074->3032|5124->3054|5154->3055|5194->3067|5228->3072|5258->3073|5301->3087|5423->3181|5453->3182|5491->3192|5588->3260|5618->3261|5661->3275|5729->3314|5759->3315|5806->3333|6092->3590|6122->3591|6191->3631|6221->3632|6268->3650|6341->3694|6371->3695|6422->3717|6590->3856|6620->3857|6675->3883|6971->4150|7001->4151|7060->4181|7148->4240|7178->4241|7229->4263|7259->4264|7312->4288|7481->4428|7511->4429|7564->4453|7594->4454|7645->4476|7742->4544|7772->4545|7815->4559|7845->4560|7924->4610|7954->4611|8001->4629|8037->4636|8067->4637|8118->4659|8284->4796|8314->4797|8354->4808|8384->4809|8516->4912|8546->4913|8601->4939|8718->5027|8748->5028|8849->5100|8879->5101|8934->5127|9106->5270|9136->5271|9183->5289|9213->5290|9258->5306|9288->5307|9330->5321|9360->5322|11465->7438|11512->7535|11559->7578|11610->7638|11661->7707|11708->7735|11751->7759|11789->7769 - LINES: 25->1|54->30|55->31|74->50|74->50|75->51|78->54|78->54|80->56|80->56|80->56|81->57|86->62|86->62|88->64|88->64|88->64|89->65|92->68|92->68|94->70|94->70|94->70|95->71|99->75|99->75|101->77|101->77|101->77|102->78|105->81|105->81|107->83|107->83|107->83|108->84|109->85|109->85|111->87|111->87|111->87|112->88|114->90|114->90|116->92|116->92|116->92|117->93|121->97|121->97|123->99|123->99|123->99|124->100|125->101|125->101|127->103|127->103|127->103|128->104|131->107|131->107|132->108|135->111|135->111|136->112|136->112|136->112|137->113|143->119|143->119|145->121|145->121|146->122|146->122|146->122|147->123|149->125|149->125|150->126|152->128|152->128|153->129|154->130|154->130|155->131|155->131|157->133|161->137|161->137|161->137|161->137|162->138|164->140|164->140|165->141|165->141|168->144|168->144|169->145|169->145|169->145|170->146|173->149|173->149|173->149|173->149|175->151|175->151|176->152|178->154|178->154|179->155|179->155|180->156|183->159|183->159|184->160|184->160|185->161|185->161|188->164|188->164|243->219|244->220|245->221|246->222|247->223|248->224|249->225|250->226 - -- GENERATED -- - */ - \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/samediff/SameDiffUI.scala.html b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/samediff/SameDiffUI.scala.html deleted file mode 100644 index 3b5a79cc1..000000000 --- 
a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/samediff/SameDiffUI.scala.html +++ /dev/null @@ -1,172 +0,0 @@ -@() - - - - - - - - - SameDiff Graph Visualization - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @*
*@ -
- - -
- - @*
*@ -
- @*
*@ -
-
- -
-
[No File Loaded]
-
- -
- - @*
*@ -
-
Selected Node:
-
(None)
-
- -
- -
-
- Find Node:
-
- -
- -
-
- -
- - @*
*@ -
-

- Graph Layout: -
- - - - -
-
-
-
-
-
- - -
-
-
-
-
- - - - - - diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestSameDiffUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestSameDiffUI.java deleted file mode 100644 index b5af4c88e..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestSameDiffUI.java +++ /dev/null @@ -1,96 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.play; - -import org.deeplearning4j.ui.api.UIServer; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.graph.ui.LogFileWriter; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.File; -import java.util.Arrays; - -@Ignore -public class TestSameDiffUI { - - @Ignore - @Test - public void testSameDiff() throws Exception { - -// -// File f = new File("C:/Temp/SameDiffUI/ui_data.bin"); -// f.getParentFile().mkdirs(); -// f.delete(); -// -// SameDiff sd = SameDiff.create(); -// SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); -// SDVariable w = sd.var("w", DataType.FLOAT, 3,4); -// SDVariable b = sd.var("b", DataType.FLOAT, 1, 4); -// -// SDVariable z = in.mmul(w).add(b); -// SDVariable a = sd.nn().tanh(z); -// -// LogFileWriter lfw = new LogFileWriter(f); -// lfw.writeGraphStructure(sd); -// lfw.writeFinishStaticMarker(); -// -// //Append a number of events -// lfw.registerEventName("accuracy"); -// lfw.registerEventName("precision"); -// long t = System.currentTimeMillis(); -// for( int iter=0; iter<50; iter++) { -// double d = Math.cos(0.1*iter); -// d *= d; -// lfw.writeScalarEvent("accuracy", t + iter, iter, 0, d); -// -// double prec = Math.min(0.05 * iter, 1.0); -// lfw.writeScalarEvent("precision", t+iter, iter, 0, prec); -// } -// -// //Add some histograms: -// lfw.registerEventName("histogramDiscrete"); -// lfw.registerEventName("histogramEqualSpacing"); -// lfw.registerEventName("histogramCustomBins"); -// for( int i=0; i<3; i++ ){ -// INDArray discreteY = Nd4j.createFromArray(0, 1, 2); -// lfw.writeHistogramEventDiscrete("histogramDiscrete", t+i, i, 0, Arrays.asList("zero", "one", "two"), discreteY); -// -// INDArray eqSpacingY = Nd4j.createFromArray(-0.5 + 0.5 * i, 0.75 * i + i, 1.0 * i + 1.0); -// lfw.writeHistogramEventEqualSpacing("histogramEqualSpacing", t+i, i, 0, 0.0, 1.0, eqSpacingY); -// -// INDArray customBins = Nd4j.createFromArray(new double[][]{ -// {0.0, 0.5, 0.9}, -// {0.2, 0.55, 1.0} -// }); -// System.out.println(Arrays.toString(customBins.data().asFloat())); -// 
System.out.println(customBins.shapeInfoToString()); -// lfw.writeHistogramEventCustomBins("histogramCustomBins", t+i, i, 0, customBins, eqSpacingY); -// } - - - UIServer uiServer = UIServer.getInstance(); - - - Thread.sleep(1_000_000_000); - } - -} diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml index e46ec021e..654e34dc1 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml @@ -134,7 +134,7 @@ org.deeplearning4j - deeplearning4j-ui_2.11 + deeplearning4j-ui ${deeplearning4j.version} diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml index 087466efc..d829971b4 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml @@ -1,5 +1,6 @@ - @@ -25,165 +24,43 @@ 4.0.0 - deeplearning4j-play_2.11 - deeplearning4j-play - - - - 2.11.12 - 2.11 - - - - - - buildUiTemplates - - false - - - - - com.google.code.play2-maven-plugin - play2-maven-plugin - ${maven-play2-plugin.version} - - - - - GenerateTemplates - compile - - - - template-compile - - - - - - - - maven-resources-plugin - ${maven-resources-plugin.version} - - - CopyTemplates - compile - - copy-resources - - - ${basedir}/src/main/scala/org/deeplearning4j/ui/views/html - - - ${basedir}/target/twirl/main/org/deeplearning4j/ui/views/html - true - - - - - - - - - - - - test-nd4j-native - - - org.nd4j - nd4j-native - ${project.version} - test - - - - - - test-nd4j-cuda-10.1 - - - org.nd4j - nd4j-cuda-10.1 - ${project.version} - test - - - - + deeplearning4j-vertx + + io.vertx + vertx-core + ${vertx.version} + + + + io.vertx + vertx-web + ${vertx.version} + + + + org.deeplearning4j + deeplearning4j-core + ${project.version} + + org.deeplearning4j deeplearning4j-ui-model ${project.version} + - javax.ws.rs - javax.ws.rs-api - ${ws.rs.version} - - - org.scala-lang - scala-library - ${scala.version} - - - org.scala-lang - scala-reflect - ${scala.version} - - - com.typesafe.play - play-java_2.11 - ${playframework.version} - - - com.google.code.findbugs - jsr305 - - - org.slf4j - jul-to-slf4j - - - org.slf4j - jcl-over-slf4j - - - net.jodah - typetools - - + ch.qos.logback + logback-classic + test - net.jodah - typetools - ${jodah.typetools.version} - - - - com.typesafe.play - play-netty-server_2.11 - ${playframework.version} - - - - org.fusesource.leveldbjni - leveldbjni-all - ${leveldb.version} + org.freemarker + freemarker + 2.3.29 @@ -192,24 +69,8 @@ ${jcommander.version} - - ch.qos.logback - logback-classic - compile - - - org.slf4j - jul-to-slf4j - compile - 1.7.21 - - - - junit - junit - test - + @@ -306,11 +167,6 @@ cytoscape-cola 2.3.0 - - org.webjars.npm - webcola - 3.1.3 - org.webjars.npm cytoscape-cose-bilkent @@ -442,68 +298,8 @@ flatbuffers 1.9.0 - - - - javax.xml.bind - jaxb-api - ${jaxb.version} - provided - - - - - - net.alchim31.maven - scala-maven-plugin - - ${scala.version} - - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-source - compile - - add-source - - - - src/main/scala - src/main/java - src/main/views - - - - - - - - com.google.code.sbt-compiler-maven-plugin - sbt-compiler-maven-plugin - ${sbt-compiler-maven-plugin.version} - - - - - org.apache.maven.plugins - maven-compiler-plugin - - 1.8 - 
1.8 - - - - - - - @**@ diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/SameDiffUI.html b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/SameDiffUI.html new file mode 100644 index 000000000..1d29f7827 --- /dev/null +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/SameDiffUI.html @@ -0,0 +1,169 @@ + + + + + + + + + SameDiff Graph Visualization + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
[No File Loaded]
+
+ +
+ +
+
Selected Node:
+
(None)
+
+ +
+ +
+
+ Find Node:
+
+ +
+ +
+
+ +
+ +
+

+ Graph Layout: +
+ + + + +
+
+
+
+
+
+ + +
+
+
+
+
+ + + + + + \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingModel.scala.html b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/TrainingModel.html.ftl similarity index 85% rename from deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingModel.scala.html rename to deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/TrainingModel.html.ftl index 81a63d26a..af302e14d 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingModel.scala.html +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/TrainingModel.html.ftl @@ -1,4 +1,3 @@ -@(i18n: org.deeplearning4j.ui.api.I18N)
-

@i18n.getMessage("train.overview.perftable.title")

+

${train\.overview\.perftable\.title}

- + - + - + - + - + - + - + - + - +
@i18n.getMessage("train.overview.modeltable.modeltype")${train\.overview\.modeltable\.modeltype} Loading...
@i18n.getMessage("train.overview.modeltable.nLayers")${train\.overview\.modeltable\.nLayers} Loading...
@i18n.getMessage("train.overview.modeltable.nParams")${train\.overview\.modeltable\.nParams} Loading...
@i18n.getMessage("train.overview.perftable.startTime")${train\.overview\.perftable\.startTime} Loading...
@i18n.getMessage("train.overview.perftable.totalRuntime")${train\.overview\.perftable\.totalRuntime} Loading...
@i18n.getMessage("train.overview.perftable.lastUpdate")${train\.overview\.perftable\.lastUpdate} Loading...
@i18n.getMessage("train.overview.perftable.totalParamUpdates")${train\.overview\.perftable\.totalParamUpdates} Loading...
@i18n.getMessage("train.overview.perftable.updatesPerSec")${train\.overview\.perftable\.updatesPerSec} Loading...
@i18n.getMessage("train.overview.perftable.examplesPerSec")${train\.overview\.perftable\.examplesPerSec} Loading...
@@ -151,15 +151,15 @@
-

@i18n.getMessage("train.overview.chart.updateRatioTitle"): log10

+

${train\.overview\.chart\.updateRatioTitle}: log10

-

@i18n.getMessage("train.overview.chart.updateRatioTitleShort") +

${train\.overview\.chart\.updateRatioTitleShort} : 0, log - 10 @i18n.getMessage("train.overview.chart.updateRatioTitleShort") + 10 ${train\.overview\.chart\.updateRatioTitleShort} : 0 - , @i18n.getMessage("train.overview.charts.iteration"): + , ${train\.overview\.charts\.iteration}: 0

@@ -168,20 +168,20 @@
-

@i18n.getMessage("train.overview.chart.stdevTitleShort") +

${train\.overview\.chart\.stdevTitleShort} : 0, log - 10 @i18n.getMessage("train.overview.chart.stdevTitleShort") + 10 ${train\.overview\.chart\.stdevTitleShort} : 0 - , @i18n.getMessage("train.overview.charts.iteration"): + , ${train\.overview\.charts\.iteration}: 0

diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingSystem.scala.html b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/TrainingSystem.html.ftl similarity index 76% rename from deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingSystem.scala.html rename to deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/TrainingSystem.html.ftl index c68032f51..569fc02b4 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingSystem.scala.html +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/TrainingSystem.html.ftl @@ -1,4 +1,3 @@ -@(i18n: org.deeplearning4j.ui.api.I18N)
-

@i18n.getMessage("train.system.chart.gpuMemoryTitle") %

+

${train\.system\.chart\.gpuMemoryTitle} %

-

@i18n.getMessage("train.system.chart.memoryShort"): 0, - @i18n.getMessage("train.overview.charts.iteration"): 0

+

${train\.system\.chart\.memoryShort}: 0, + ${train\.overview\.charts\.iteration}: 0

@@ -135,18 +134,18 @@
-

@i18n.getMessage("train.system.hwTable.title")

+

${train\.system\.hwTable\.title}

- - - - - - + + + + + + @@ -170,19 +169,19 @@
-

@i18n.getMessage("train.system.swTable.title")

+

${train\.system\.swTable\.title}

@i18n.getMessage("train.system.hwTable.jvmCurrent")@i18n.getMessage("train.system.hwTable.jvmMax")@i18n.getMessage("train.system.hwTable.offHeapCurrent")@i18n.getMessage("train.system.hwTable.offHeapMax")@i18n.getMessage("train.system.hwTable.jvmProcs")@i18n.getMessage("train.system.hwTable.computeDevices")${train\.system\.hwTable\.jvmCurrent}${train\.system\.hwTable\.jvmMax}${train\.system\.hwTable\.offHeapCurrent}${train\.system\.hwTable\.offHeapMax}${train\.system\.hwTable\.jvmProcs}${train\.system\.hwTable\.computeDevices}
- - - - - - - + + + + + + + @@ -201,29 +200,6 @@ - - @**@ - @*
*@ - @*
*@ - @*
*@ - @*

GPU Information

*@ - @*
*@ - @*
*@ - @*
@i18n.getMessage("train.system.swTable.hostname")@i18n.getMessage("train.system.swTable.os")@i18n.getMessage("train.system.swTable.osArch")@i18n.getMessage("train.system.swTable.jvmName")@i18n.getMessage("train.system.swTable.jvmVersion")@i18n.getMessage("train.system.swTable.nd4jBackend")@i18n.getMessage("train.system.swTable.nd4jDataType")${train\.system\.swTable\.hostname}${train\.system\.swTable\.os}${train\.system\.swTable\.osArch}${train\.system\.swTable\.jvmName}${train\.system\.swTable\.jvmVersion}${train\.system\.swTable\.nd4jBackend}${train\.system\.swTable\.nd4jDataType}
*@ - @**@ - @**@ - @**@ - @**@ - @**@ - @**@ - @**@ - @**@ - @**@ - @**@ - @*
?
Loading...
*@ - @*
*@ - @*
*@ - @*
*@
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/tsne/Tsne.scala.html b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/Tsne.html similarity index 91% rename from deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/tsne/Tsne.scala.html rename to deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/Tsne.html index 5a0cd065c..76e1d7f96 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/tsne/Tsne.scala.html +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/templates/Tsne.html @@ -24,11 +24,10 @@ - + - @**@ - + @@ -215,14 +214,6 @@
- - @*
*@ - @*
If a file is already present on the server, specify the path/name.
*@ - @*
*@ - @**@ - @**@ - @*
*@ - @*
*@
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestRemoteReceiver.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java similarity index 91% rename from deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestRemoteReceiver.java rename to deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java index eaf6cbcfe..a5469bd1d 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestRemoteReceiver.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,10 +15,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.ui.play; +package org.deeplearning4j.ui; import org.deeplearning4j.api.storage.Persistable; -import org.deeplearning4j.api.storage.StatsStorageRouter; import org.deeplearning4j.api.storage.StorageMetaData; import org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter; import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter; @@ -40,7 +40,6 @@ import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.rmi.Remote; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -65,19 +64,24 @@ public class TestRemoteReceiver { UIServer s = UIServer.getInstance(); + Thread.sleep(1000); s.enableRemoteListener(collectionRouter, false); try(RemoteUIStatsStorageRouter remoteRouter = new RemoteUIStatsStorageRouter("http://localhost:9000")) { //Use closeable to ensure async thread is shut down SbeStatsReport update1 = new SbeStatsReport(); + update1.setMemoryUsePresent(true); update1.setDeviceCurrentBytes(new long[]{1, 2}); + update1.setDeviceMaxBytes(new long[]{100, 200}); update1.reportIterationCount(10); update1.reportIDs("sid", "tid", "wid", 123456); update1.reportPerformance(10, 20, 30, 40, 50); SbeStatsReport update2 = new SbeStatsReport(); + update2.setMemoryUsePresent(true); update2.setDeviceCurrentBytes(new long[]{3, 4}); + update2.setDeviceMaxBytes(new long[]{300, 400}); update2.reportIterationCount(20); update2.reportIDs("sid2", "tid2", "wid2", 123456); update2.reportPerformance(11, 21, 31, 40, 50); @@ -89,6 +93,9 @@ public class TestRemoteReceiver { SbeStatsInitializationReport init1 = new SbeStatsInitializationReport(); init1.reportIDs("sid", "wid", "tid", 3145253452L); init1.reportHardwareInfo(1, 2, 3, 4, null, null, "2344253"); + init1.setHwDeviceTotalMemory(new long[]{1,2}); + init1.setHwDeviceDescription(new String[]{"d1", "d2"}); + init1.setHasHardwareInfo(true); remoteRouter.putUpdate(update1); Thread.sleep(100); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java 
b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java new file mode 100644 index 000000000..a9da39dbc --- /dev/null +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java @@ -0,0 +1,106 @@ +/* ****************************************************************************** + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.ui; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.ui.api.UIServer; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.graph.ui.LogFileWriter; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.io.File; +import java.util.Arrays; + +@Ignore +@Slf4j +public class TestSameDiffUI { + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + + @Ignore + @Test + public void testSameDiff() throws Exception { + File dir = testDir.newFolder(); + File f = new File(dir, "ui_data.bin"); + log.info("File path: {}", f.getAbsolutePath()); + + f.getParentFile().mkdirs(); + f.delete(); + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); + SDVariable w = sd.var("w", DataType.FLOAT, 3,4); + SDVariable b = sd.var("b", DataType.FLOAT, 1, 4); + + SDVariable z = in.mmul(w).add(b); + SDVariable a = sd.nn().tanh(z); + + LogFileWriter lfw = new LogFileWriter(f); + lfw.writeGraphStructure(sd); + lfw.writeFinishStaticMarker(); + + //Append a number of events + lfw.registerEventName("accuracy"); + lfw.registerEventName("precision"); + long t = System.currentTimeMillis(); + for( int iter=0; iter<50; iter++) { + double d = Math.cos(0.1*iter); + d *= d; + lfw.writeScalarEvent("accuracy", LogFileWriter.EventSubtype.EVALUATION, t + iter, iter, 0, d); + + double prec = Math.min(0.05 * iter, 1.0); + lfw.writeScalarEvent("precision", LogFileWriter.EventSubtype.EVALUATION, t+iter, iter, 0, prec); + } + + //Add some histograms: + lfw.registerEventName("histogramDiscrete"); + lfw.registerEventName("histogramEqualSpacing"); + lfw.registerEventName("histogramCustomBins"); + for( int i=0; i<3; i++ ){ + INDArray discreteY = Nd4j.createFromArray(0, 1, 2); + lfw.writeHistogramEventDiscrete("histogramDiscrete", LogFileWriter.EventSubtype.TUNING_METRIC, t+i, i, 0, Arrays.asList("zero", "one", "two"), discreteY); + + INDArray eqSpacingY = Nd4j.createFromArray(-0.5 + 0.5 * i, 0.75 * i + i, 1.0 * i + 1.0); + lfw.writeHistogramEventEqualSpacing("histogramEqualSpacing", LogFileWriter.EventSubtype.TUNING_METRIC, t+i, i, 0, 0.0, 1.0, eqSpacingY); + + INDArray 
customBins = Nd4j.createFromArray(new double[][]{ + {0.0, 0.5, 0.9}, + {0.2, 0.55, 1.0} + }); + System.out.println(Arrays.toString(customBins.data().asFloat())); + System.out.println(customBins.shapeInfoToString()); + lfw.writeHistogramEventCustomBins("histogramCustomBins", LogFileWriter.EventSubtype.TUNING_METRIC, t+i, i, 0, customBins, eqSpacingY); + } + + + UIServer uiServer = UIServer.getInstance(); + + + Thread.sleep(1_000_000_000); + } + +} diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestPlayUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java similarity index 94% rename from deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestPlayUI.java rename to deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java index b3c3ed5cb..a60147d0f 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestPlayUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,7 +15,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.ui.play; +package org.deeplearning4j.ui; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; @@ -41,15 +42,13 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Created by Alex on 08/10/2016. 
*/ @Ignore -public class TestPlayUI { +public class TestVertxUI { @Before public void setUp() throws Exception { UIServer.stopInstance(); @@ -61,14 +60,14 @@ public class TestPlayUI { StatsStorage ss = new InMemoryStatsStorage(); - PlayUIServer uiServer = (PlayUIServer) UIServer.getInstance(); + VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance(); assertEquals(9000, uiServer.getPort()); uiServer.stop(); - PlayUIServer playUIServer = new PlayUIServer(); - playUIServer.runMain(new String[] {"--uiPort", "9100", "-r", "true"}); - - assertEquals(9100, playUIServer.getPort()); - playUIServer.stop(); + VertxUIServer vertxUIServer = new VertxUIServer(); +// vertxUIServer.runMain(new String[] {"--uiPort", "9100", "-r", "true"}); +// +// assertEquals(9100, vertxUIServer.getPort()); +// vertxUIServer.stop(); // uiServer.attach(ss); @@ -249,7 +248,7 @@ public class TestPlayUI { Thread.sleep(100); } - Thread.sleep(100000); + Thread.sleep(1000000); } @Test diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestPlayUIMultiSession.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java similarity index 94% rename from deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestPlayUIMultiSession.java rename to deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java index 8d98ac6fc..b640be8c7 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestPlayUIMultiSession.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,8 +15,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.ui.play; +package org.deeplearning4j.ui; +import io.netty.handler.codec.http.HttpResponseStatus; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; @@ -34,8 +36,8 @@ import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.function.Function; +import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.LossFunctions; -import play.mvc.Http; import java.io.IOException; import java.io.UnsupportedEncodingException; @@ -50,14 +52,14 @@ import static org.junit.Assert.*; * @author Tamas Fenyvesi */ @Ignore -public class TestPlayUIMultiSession { +public class TestVertxUIMultiSession { @Before public void setUp() throws Exception { UIServer.stopInstance(); } @Test - public void testUIMultiSession() { + public void testUIMultiSession() throws Exception { UIServer uIServer = UIServer.getInstance(true, null); HashMap statStorageForThread = new HashMap<>(); @@ -73,6 +75,7 @@ public class TestPlayUIMultiSession { Thread training = new Thread(() -> { int layerSize = sid + 4; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .updater(new Adam(1e-2)) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build()) .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) @@ -82,7 +85,7 @@ public class TestPlayUIMultiSession { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - StatsListener statsListener = new StatsListener(ss); + StatsListener statsListener = new StatsListener(ss, 1); statsListener.setSessionID(sessionId); net.setListeners(statsListener, new ScoreIterationListener(1)); @@ -100,6 +103,8 @@ public class TestPlayUIMultiSession { sessionIdForThread.put(training, sessionId); } + Thread.sleep(10000000); + for (Thread thread: statStorageForThread.keySet()) { StatsStorage ss = statStorageForThread.get(thread); String sessionId = sessionIdForThread.get(thread); @@ -112,7 +117,7 @@ public class TestPlayUIMultiSession { HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection(); conn.connect(); - assertEquals(Http.Status.OK, conn.getResponseCode()); + assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); assertTrue(uIServer.isAttached(ss)); } catch (InterruptedException | IOException e) { e.printStackTrace(); @@ -170,7 +175,7 @@ public class TestPlayUIMultiSession { HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection(); conn.connect(); - assertEquals(Http.Status.OK, conn.getResponseCode()); + assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId))); } } @@ -222,7 +227,7 @@ public class TestPlayUIMultiSession { HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection(); conn.connect(); - assertEquals(Http.Status.OK, conn.getResponseCode()); + assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); 
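// Note: with Play removed, the HTTP status constants in these assertions come from Netty's
// HttpResponseStatus (Vert.x runs on Netty). The per-session checks in this test reduce to the
// following pattern; sessionUrl and ss are the variables already in scope in the surrounding hunk:
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
conn.connect();
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); // 200 while the session is live
assertTrue(uIServer.isAttached(ss));                                // its stats storage is still registered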
assertTrue(uIServer.isAttached(ss)); } @@ -261,7 +266,6 @@ public class TestPlayUIMultiSession { * */ private static class AutoDetachingStatsStorageProvider implements Function { - HashMap storageForSession = new HashMap<>(); UIServer uIServer; long autoDetachTimeoutMillis; @@ -303,5 +307,4 @@ public class TestPlayUIMultiSession { return statsStorage; } } - } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/resources/logback.xml similarity index 90% rename from deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/resources/logback.xml rename to deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/resources/logback.xml index c6f89b60a..2c204cafa 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/resources/logback.xml @@ -16,6 +16,8 @@ + + logs/application.log @@ -33,9 +35,10 @@ - - - + + + + diff --git a/deeplearning4j/deeplearning4j-ui-parent/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/pom.xml index 701d73166..678b1c318 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/pom.xml @@ -1,6 +1,7 @@ - - 1.8 - 1.8 - - org.apache.maven.plugins @@ -102,6 +94,18 @@ + + + + org.apache.maven.plugins + maven-compiler-plugin + + 1.8 + 1.8 + + + + diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt index 9ce9b46a3..50c6b9b8a 100755 --- a/libnd4j/CMakeLists.txt +++ b/libnd4j/CMakeLists.txt @@ -51,7 +51,7 @@ endif() if(NOT CUDA_BLAS) # we need this definition to avoid global memory use within mkldnn - add_definitions(-DMKLDNN_ENABLE_CONCURRENT_EXEC=true) + add_definitions(-DDNNL_ENABLE_CONCURRENT_EXEC=true) # there's a chance, we have no BLAS provided externally if ("${OPENBLAS_PATH}" STREQUAL "") @@ -122,7 +122,7 @@ if(NOT CUDA_BLAS) if (${HELPERS_mkldnn}) message("Going to pull & build mkldnn") set(HAVE_MKLDNN 1) - set(MKLDNN_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE) + set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "Hack to enforce static mode" FORCE) configure_file(./CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt) execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 
@@ -146,7 +146,7 @@ if(NOT CUDA_BLAS) set(mkldnn_EXT_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src) set(MKLDNN_PATH "${mkldnn_SOURCE_DIR}") include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_EXT_DIR}/include ${mkldnn_SOURCE_DIR}) - set(MKLDNN mkldnn) + set(MKLDNN dnnl) endif() endif() diff --git a/libnd4j/CMakeLists.txt.mkldnn.in b/libnd4j/CMakeLists.txt.mkldnn.in index ac0f9accf..2b773abea 100644 --- a/libnd4j/CMakeLists.txt.mkldnn.in +++ b/libnd4j/CMakeLists.txt.mkldnn.in @@ -5,11 +5,11 @@ project(mkldnn-download NONE) include(ExternalProject) ExternalProject_Add(mkldnn GIT_REPOSITORY https://github.com/intel/mkl-dnn.git - GIT_TAG v1.0.4 + GIT_TAG v1.1.1 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build" CONFIGURE_COMMAND "" - CMAKE_ARGS -DMKLDNN_USE_MKL=ML -DMKLDNN_LIBRARY_TYPE=STATIC -G \"Unix Makefiles\" + CMAKE_ARGS -DDNNL_USE_MKL=ML -DDNNL_LIBRARY_TYPE=STATIC -G \"Unix Makefiles\" BUILD_COMMAND "" INSTALL_COMMAND "" TEST_COMMAND "" diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index de2488f9d..cfad05b49 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -1286,6 +1286,11 @@ namespace nd4j { */ Nd4jLong sizeAt(const int dim) const; + /** + * returns stride of "dim" dimension + */ + Nd4jLong strideAt(const int dim) const; + /** * returns order of array */ diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index c4a631cf5..00a984d45 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -1439,9 +1439,21 @@ Nd4jLong NDArray::sizeAt(const int dim) const { throw std::runtime_error("Bad size index requested"); if (dim >= 0) - return this->_shapeInfo[1+dim]; + return shape::shapeOf(_shapeInfo)[dim]; else - return this->_shapeInfo[1+(this->rankOf() + dim)]; + return shape::shapeOf(_shapeInfo)[this->rankOf() + dim]; +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::strideAt(const int dim) const { + + if (dim >= this->rankOf() || dim < -this->rankOf()) + throw std::runtime_error("NDArray::strideAt: Bad size index requested"); + + if (dim >= 0) + return shape::stride(_shapeInfo)[dim]; + else + return shape::stride(_shapeInfo)[this->rankOf() + dim]; } ////////////////////////////////////////////////////////////////////////// @@ -2760,9 +2772,9 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector // TODO: eventually we want separate tads here NDArray::prepareSpecialUse({result}, {this, other}); if(max == this) - NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); else - NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), 
specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); registerSpecialUse({result}, {this, other}); } diff --git a/libnd4j/blas/NativeOpExecutioner.h b/libnd4j/blas/NativeOpExecutioner.h index fb2ca58f0..b4a5fdea4 100644 --- a/libnd4j/blas/NativeOpExecutioner.h +++ b/libnd4j/blas/NativeOpExecutioner.h @@ -284,6 +284,7 @@ static void execScalarInt(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); @@ -296,6 +297,7 @@ static void execScalarInt(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index ff368d7c8..ea2b0b8a5 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -179,6 +179,7 @@ ND4J_EXPORT void execBroadcastBool( void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape); diff --git a/libnd4j/blas/cpu/NativeOpExecutioner.cpp b/libnd4j/blas/cpu/NativeOpExecutioner.cpp index dc27c1cce..75a68c984 100644 --- a/libnd4j/blas/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/blas/cpu/NativeOpExecutioner.cpp @@ -156,6 +156,9 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); #else @@ -187,7 +190,8 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; if (!nd4j::Environment::getInstance()->isExperimentalBuild()) if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType) @@ -219,6 +223,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, 
Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { @@ -228,8 +233,11 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); }; auto xLen = shape::length(hXShapeInfo); @@ -247,22 +255,24 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { - - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!nd4j::Environment::getInstance()->isExperimentalBuild()) if (yType != xType || nd4j::DataType::BOOL != zType) throw nd4j::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType); auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); }; auto xLen = shape::length(hXShapeInfo); @@ -292,6 +302,9 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (xType != yType || xType != zType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType); @@ -321,12 +334,13 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { - - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (xType != yType || xType != zType) throw 
nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType); @@ -367,11 +381,13 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); #else @@ -403,6 +419,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (xType != yType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType); @@ -429,11 +448,13 @@ void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (xType != yType || xType != zType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType); @@ -837,11 +858,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, void *dScalar, Nd4jLong *dScalarShapeInfo, void *extraParams, bool allowParallelism) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); #else @@ -872,11 +895,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); #else @@ -904,12 +929,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, void *dScalar, Nd4jLong *dSscalarShapeInfo, void *extraParams, bool allowParallelism) { - - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) + return; + if (xType != yType) 
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType); @@ -939,11 +965,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType); @@ -969,12 +997,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, void *dScalar, Nd4jLong *dSscalarShapeInfo, void *extraParams, bool allowParallelism) { - - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) + return; + if (xType != yType || xType != zType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); @@ -1004,11 +1033,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType || xType != zType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); @@ -1126,6 +1157,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + auto func = PRAGMA_THREADS_DO { BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES); }; @@ -1145,6 +1179,9 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + auto func = PRAGMA_THREADS_DO { BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES); }; @@ -1164,6 +1201,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + auto func = PRAGMA_THREADS_DO { BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES); }; @@ -1183,6 +1223,9 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + auto func = PRAGMA_THREADS_DO { BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, 
::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES); }; @@ -1202,6 +1245,9 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + auto func = PRAGMA_THREADS_DO { BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES); }; diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index df6ccc240..e790c05d0 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -231,8 +231,9 @@ void execBroadcastBool(Nd4jPointer *extraPointers, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + void *extraParams, + void *hDimension, Nd4jLong *hDimensionShape, + void *dDimension, Nd4jLong *dDimensionShape) { try { auto dimension = reinterpret_cast(hDimension); int dimensionLength = static_cast(shape::length(hDimensionShape)); @@ -259,6 +260,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, + extraParams, dimension, dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, hTADOffsetsZ); diff --git a/libnd4j/blas/cuda/NativeOpExecutioner.cu b/libnd4j/blas/cuda/NativeOpExecutioner.cu index fcb473820..1f074f39b 100644 --- a/libnd4j/blas/cuda/NativeOpExecutioner.cu +++ b/libnd4j/blas/cuda/NativeOpExecutioner.cu @@ -101,6 +101,9 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (xType != zType && yType != zType) throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform requires Z operand to have either X or Y type"); if (lc == nullptr) @@ -139,6 +142,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform( nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isB(zType)) throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data type", nd4j::DataType::BOOL, zType); @@ -172,6 +178,9 @@ void NativeOpExecutioner::execPairwiseIntTransform( nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isZ(zType)) throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::BOOL, zType); @@ -223,6 +232,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { @@ -233,6 +243,9 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc, auto yType = 
nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isB(zType)) throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type"); @@ -244,7 +257,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc, dim3 launchDims(256, 256, 1024); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) // TODO: remove after the release auto res = cudaStreamSynchronize(*stream); @@ -260,6 +273,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { @@ -269,18 +283,18 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isB(zType)) throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type"); if (yType != xType) throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type"); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3BI opNum:[%i]\n", opNum); - dim3 launchDims(256, 256, 1024); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) // TODO: remove after the release auto res = cudaStreamSynchronize(*stream); @@ -308,15 +322,15 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isZ(zType)) throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); if (yType != xType || zType != xType) throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3B opNum:[%i]\n", opNum); - dim3 launchDims(256, 256, 1024); 
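// Note: every NativeOpExecutioner::exec* path, on both the CPU and CUDA back ends, now opens with
// the same guard, so empty inputs become silent no-ops instead of kernel launches on zero-length
// buffers. A minimal sketch of the convention used throughout these hunks:
//
//     if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
//         return;   // empty input, empty result; skip type checks and dispatch entirely
//
// execBroadcastBool and execInverseBroadcastBool additionally gain a void* extraParams argument,
// forwarded into the BroadcastBool functor dispatch; the NDArray::applyBroadcast call sites above
// currently pass nullptr for it.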
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) @@ -344,6 +358,9 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isZ(zType)) throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); @@ -394,8 +411,8 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3 opNum:[%i]\n", opNum); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; dim3 launchDims(256, 256, 1024); @@ -429,8 +446,8 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3I opNum:[%i]\n", opNum); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; dim3 launchDims(256, 256, 1024); @@ -832,16 +849,21 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto stream = lc->getCudaStream(); - dim3 launchDims(512, 512, 16384); auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); - if (xType != zType) - throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type"); + if (shape::isEmpty(hXShapeInfo)) { + return; + } + if (xType != zType) { + throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type"); + } + + dim3 launchDims(512, 512, 16384); BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES); // TODO: remove after the release @@ -861,16 +883,21 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto stream = lc->getCudaStream(); - dim3 launchDims(512, 512, 16384); - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); - if (!DataTypeUtils::isB(zType)) - throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type"); + if (shape::isEmpty(hXShapeInfo)) { + return; + } + if (!DataTypeUtils::isB(zType)) { + throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type"); + } + + dim3 launchDims(512, 512, 16384); BUILD_DOUBLE_SELECTOR(xType, zType, 
functions::transform::TransformBool, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, BOOL_TYPES); // TODO: remove after the release @@ -896,6 +923,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc, auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + dim3 launchDims(512, 512, 2048); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES); @@ -917,16 +947,21 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto stream = lc->getCudaStream(); - dim3 launchDims(512, 512, 16384); auto xRank = shape::rank(hXShapeInfo); auto zRank = shape::rank(hZShapeInfo); auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); - if (xType != zType || !DataTypeUtils::isR(xType)) - throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType); + if (shape::isEmpty(hXShapeInfo)) { + return; + } + if (xType != zType || !DataTypeUtils::isR(xType)) { + throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType); + } + + dim3 launchDims(512, 512, 16384); BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES); // TODO: remove after the release @@ -953,6 +988,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc, auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + if (!DataTypeUtils::isR(zType)) throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType); @@ -1175,6 +1213,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType ) throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type"); @@ -1211,6 +1252,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType ) throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type"); @@ -1244,6 +1288,9 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType || zType != xType) throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); @@ -1280,6 +1327,9 
@@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType || zType != xType) throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); @@ -1313,6 +1363,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); @@ -1346,6 +1399,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + dim3 launchDims(256, 256, 16384); #ifdef __ND4J_EXPERIMENTAL__ diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index cda6acbad..024012808 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -294,6 +294,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { try { @@ -313,7 +314,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers, LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, + dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); } catch (std::exception &e) { diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp index 8812b7802..1c34f25d9 100644 --- a/libnd4j/include/execution/cpu/LaunchContext.cpp +++ b/libnd4j/include/execution/cpu/LaunchContext.cpp @@ -30,14 +30,14 @@ thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers(); #endif #ifdef HAVE_MKLDNN -#include +#include #endif namespace nd4j { LaunchContext::~LaunchContext() { #ifdef HAVE_MKLDNN - delete reinterpret_cast(_engine); + delete reinterpret_cast(_engine); #endif } @@ -50,7 +50,7 @@ namespace nd4j { _deviceID = 0; #ifdef HAVE_MKLDNN - _engine = new mkldnn::engine(mkldnn::engine::kind::cpu, 0); + _engine = new dnnl::engine(dnnl::engine::kind::cpu, 0); #endif } diff --git a/libnd4j/include/helpers/MmulHelper.h b/libnd4j/include/helpers/MmulHelper.h index 891525a73..ff0a7d1b2 100644 --- a/libnd4j/include/helpers/MmulHelper.h +++ b/libnd4j/include/helpers/MmulHelper.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -53,7 +54,7 @@ namespace nd4j { #ifndef __JAVACPP_HACK__ /** - * modif - (can be empty) vector containing a subsequence of permutation/reshaping arrays (in any order), user must take care of correctness of such arrays by himself + * modif - (can be empty) vector containing a subsequence of permutation/reshaping arrays (in any order), user must take care of correctness of such arrays by himself */ static void tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector>& modifA, const std::vector>& modifB, const std::vector>& modifC); static nd4j::NDArray* tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector>& modifA, const std::vector>& modifB); diff --git a/libnd4j/include/helpers/cpu/ConstantHelper.cpp b/libnd4j/include/helpers/cpu/ConstantHelper.cpp index b2549e93f..2ba2cc4e0 100644 --- a/libnd4j/include/helpers/cpu/ConstantHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantHelper.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index fca40d564..0f495bf96 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -21,6 +21,7 @@ #include "../MmulHelper.h" #include #include +#include #include #include @@ -28,110 +29,124 @@ namespace nd4j { ////////////////////////////////////////////////////////////////////////////// -// MXK x KxN = MxN +// MXK x KxN = MxN -> actual sequence of axes doesn't matter template -static void usualGemm(const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { +static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, + const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { - T1* A = reinterpret_cast(const_cast(vA)); - T2* B = reinterpret_cast(const_cast(vB)); - T3* C = reinterpret_cast(vC); - T3 alphaZ(alpha), betaZ(beta); - - const bool flagC = cOrder == 'f'; - const bool flagA = (flagC && transA) || (!flagC && !transA); - const bool flagB = (flagC && transB) || (!flagC && !transB); + const T1* A = vA->bufferAsT(); + const T2* B = vB->bufferAsT(); + T3* C = vC->bufferAsT(); - // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) - // for(uint row = 0; row < M; ++row) { + const T3 alphaZ = alpha; + const T3 betaZ = beta; - // T3* c = flagC ? (C + row) : (C + row * ldc); + const bool betaPersent = beta; - // for(uint col = 0; col < N; ++col) - // c[flagC ? col * ldc : col] = 0; + const Nd4jLong* aShapeInfo = vA->getShapeInfo(); + const Nd4jLong* bShapeInfo = vB->getShapeInfo(); + const Nd4jLong* cShapeInfo = vC->getShapeInfo(); - // for(uint i = 0; i < K; ++i) { - - // T3* b = flagB ? (B + i * ldb) : (B + i); - // T3* a = flagA ? 
(A + row * lda + i) : (A + row + i * lda); + const int aRank = vA->rankOf(); + const int bRank = vB->rankOf(); + const int cRank = vC->rankOf(); - // if(flagC) { - // PRAGMA_OMP_SIMD - // for(uint col = 0; col < N; ++col) { - // if(betaZ) - // c[col * ldc] += a * b[flagB ? col : col * ldb] + betaZ * c[col * ldc]; - // else - // c[col * ldc] += a * b[flagB ? col : col * ldb]; - // } - // } - // else { - // PRAGMA_OMP_SIMD - // for(uint col = 0; col < N; ++col) { - // if(betaZ) - // c[col] += a * b[flagB ? col : col * ldb] + betaZ * c[col]; - // else - // c[col] += a * b[flagB ? col : col * ldb]; - // } - // } - // } - // } + const Nd4jLong cLen = vC->lengthOf(); - auto func = PRAGMA_THREADS_FOR_2D { ; - for (auto row = start_x; row < stop_x; row += inc_x) { - for (auto col = start_y; col < stop_y; col += inc_y) { - T3 *c = flagC ? (C + row + col * ldc) : (C + row * ldc + col); - T3 val = 0; - - PRAGMA_OMP_SIMD - for (uint i = 0; i < K; ++i) { - T3 a = flagA ? *(A + row * lda + i) : *(A + row + i * lda); - T3 b = flagB ? *(B + col + i * ldb) : *(B + col * ldb + i); - val += alphaZ * a * b; - } - - if (betaZ) - *c = val + betaZ * *c; - else - *c = val; - } - } - }; - - samediff::Threads::parallel_for(func, 0, M, 1, 0, N, 1); -} - -////////////////////////////////////////////////////////////////////////////// -// MXN x N = M -template -static void usualGemv(const char aOrder, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vX, const int incx, const double beta, void* vY, const int incy) { - - T1* A = reinterpret_cast(const_cast(vA)); - T2* X = reinterpret_cast(const_cast(vX)); - T3* Y = reinterpret_cast(vY); - T3 alphaZ(alpha), betaZ(beta); - - const bool flagA = aOrder == 'f'; + const int K = vA->sizeAt(aKaxis); auto func = PRAGMA_THREADS_FOR { - for (auto row = start; row < stop; row += increment) { - T3 *y = Y + row * incy; - T3 val = 0; + std::vector aCoords(2), bCoords(2), cCoords(2); - PRAGMA_OMP_SIMD - for (int i = 0; i < N; ++i) { - T3 a = flagA ? 
*(A + row + i * lda) : *(A + row * lda + i); - T3 x = *(X + i * incx); - val += alphaZ * a * x; + for (auto i = start; i < stop; ++i) { + + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords.data()); + + // evaluate A coordinates + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; + + // evaluate B coordinates + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; + + auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); + auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); + + T3 val = A[aOffset] * B[bOffset]; // first iteration + + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; } - if (betaZ) - *y = val + betaZ * *y; + auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); + + if(betaPersent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; else - *y = val; + C[cOffset] = alphaZ * val; } }; - samediff::Threads::parallel_for(func, 0, M); + samediff::Threads::parallel_tad(func, 0, cLen); +} + + +////////////////////////////////////////////////////////////////////////////// +// MXN x N = M -> actual sequence of {M,N} axes doesn't matter +template +static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { + + const T1* A = vA->bufferAsT(); + const T2* X = vX->bufferAsT(); + T3* Y = vY->bufferAsT(); + + const T3 alphaZ = alpha; + const T3 betaZ = beta; + + const bool betaPersent = beta; + + const Nd4jLong* aShapeInfo = vA->getShapeInfo(); + const Nd4jLong* xShapeInfo = vX->getShapeInfo(); + const Nd4jLong* yShapeInfo = vY->getShapeInfo(); + + const int N = vX->lengthOf(); + const int M = vY->lengthOf(); + + const auto aMstride = vA->strideAt(aMaxis); + const auto aNstride = vA->strideAt(aMaxis == 0 ? 
1 : 0); + + auto func = PRAGMA_THREADS_FOR { + + for (auto i = start; i < stop; ++i) { + + // evaluate offsets + auto aOffset = i * aMstride; + auto xOffset = 0; + + T3 val = A[aOffset] * X[xOffset]; // first iteration + + for (uint j = 1; j < N; ++j) { // rest iterations + aOffset += aNstride; + xOffset += incx; + val = val + A[aOffset] * X[xOffset]; + } + + auto yOffset = i * incy; + + if(betaPersent) + Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; + else + Y[yOffset] = alphaZ * val; + } + }; + + samediff::Threads::parallel_tad(func, 0, M); } ////////////////////////////////////////////////////////////////////////////// @@ -144,12 +159,17 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX, T3* Z = reinterpret_cast(vZ); T3 alphaZ(alpha), betaZ(beta); + const bool betaPersent = beta; + T3 sum = 0; PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum)) for(int i = 0; i < length; ++i) sum += X[i * incx] * Y[i * incy]; - - *Z = alphaZ * sum + betaZ * *Z; + + if(betaPersent) + *Z = alphaZ * sum + betaZ * *Z; + else + *Z = alphaZ * sum; } ////////////////////////////////////////////////////////////////////////////// @@ -164,16 +184,15 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con if(A->rankOf() != 2) throw std::runtime_error("MmulHelper::mmulMxM: rank of A array is not equal 2 !"); if(B->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM: rank of B array is not equal 2 !"); + throw std::runtime_error("MmulHelper::mmulMxM: rank of B array is not equal 2 !"); + + const auto M = A->sizeAt(0); + const auto K = A->sizeAt(1); + const auto N = B->sizeAt(1); - const auto M = A->sizeAt(0); - const auto K = A->sizeAt(1); - const auto N = B->sizeAt(1); - const auto bRows = B->sizeAt(0); - if(C != nullptr && C->rankOf() != 2) throw std::runtime_error("MmulHelper::mmulMxM: rank of C array is not equal 2 !"); - if(bRows != K) + if(B->sizeAt(0) != K) throw std::runtime_error("MmulHelper::mmulMxM: B array has wrong number of rows !"); if(C != nullptr && C->sizeAt(0) != M) throw std::runtime_error("MmulHelper::mmulMxM: C array has wrong number of rows !"); @@ -181,61 +200,79 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con throw std::runtime_error("MmulHelper::mmulMxM: C array has wrong number of columns !"); if(C == nullptr) - C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); - NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); - - const auto cOrder = C->ordering(); - - if(A->ews() != 1) - pA = pA->dup(cOrder); - if(B->ews() != 1) - pB = pB->dup(cOrder); - if(C->ews() != 1) - pC = pC->dup(cOrder); - - const auto aOrder = pA->ordering(); - const auto bOrder = pB->ordering(); - - const bool transA = aOrder != cOrder; - const bool transB = bOrder != cOrder; - - const CBLAS_ORDER blasOrder = cOrder == 'f' ? CblasColMajor : CblasRowMajor; - const CBLAS_TRANSPOSE transAblas = transA ? CblasTrans : CblasNoTrans; - const CBLAS_TRANSPOSE transBblas = transB ? CblasTrans : CblasNoTrans; - - const int lda = aOrder == 'f' ? M : K; - const int ldb = bOrder == 'f' ? K : N; - const int ldc = cOrder == 'f' ? 
M : N; - - const auto aType = pA->dataType(); - const auto bType = pB->dataType(); - const auto cType = pC->dataType(); + const auto aType = A->dataType(); + const auto bType = B->dataType(); + const auto cType = C->dataType(); const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); const bool hasGemm = BlasHelper::getInstance()->hasGEMM(aType); - - // we'll use platform-specific gemm here eventually. maybe tomorrow. - // TODO: put proper _gemm here - if (ABC && hasGemm && aType == DataType::FLOAT32) { - BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, reinterpret_cast(pA->getBuffer()), lda, reinterpret_cast(pB->getBuffer()), ldb, (float) beta, reinterpret_cast(pC->getBuffer()), ldc); - } - else if (ABC && hasGemm && aType == DataType::DOUBLE) { - BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast(pA->getBuffer()), lda, reinterpret_cast(pB->getBuffer()), ldb, (double) beta, reinterpret_cast(pC->getBuffer()), ldc); + + const bool typeDouble = hasGemm && ABC && aType == DataType::DOUBLE; + const bool typeFloat = hasGemm && ABC && aType == DataType::FLOAT32; + + if(!typeFloat && !typeDouble) { + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (A, B, C, 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES); + // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (A, B, C, 0, 1, 0, 1, 0, 1, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } else { - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), NUMERIC_TYPES); - //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); - } - if(pC != C) { - C->assign(pC); - delete pC; + std::vector toDelete; + + NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aKcont = K == 1 || A->strideAt(1) == 1; + bool bKcont = K == 1 || B->strideAt(0) == 1; + bool bNcont = N == 1 || B->strideAt(1) == 1; + bool cMcont = M == 1 || C->strideAt(0) == 1; + bool cNcont = N == 1 || C->strideAt(1) == 1; + + if(!aMcont && !aKcont) { + pA = A->dup('f'); + toDelete.push_back(pA); + aMcont = true; + } + if(!bKcont && !bNcont) { + pB = B->dup('f'); + toDelete.push_back(pB); + bKcont = true; + } + if(!cMcont && !cNcont) { + pC = C->dup('f'); + toDelete.push_back(pC); + cMcont = true; + } + + const CBLAS_ORDER blasOrder = cMcont ? CblasColMajor : CblasRowMajor; + + const bool transA = (!aMcont && cMcont) || (aMcont && !cMcont); + const bool transB = (!bKcont && cMcont) || (bKcont && !cMcont); + + const CBLAS_TRANSPOSE transAblas = transA ? CblasTrans : CblasNoTrans; + const CBLAS_TRANSPOSE transBblas = transB ? CblasTrans : CblasNoTrans; + + const int lda = (aMcont && aKcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); + const int ldb = (bKcont && bNcont) ? K : !bKcont ? pB->strideAt(0) : pB->strideAt(1); + const int ldc = (cMcont && cNcont) ? M : !cMcont ? 
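//////////////////////////////////////////////////////////////////////////////
// Worked example for the lda/ldb/ldc selection above (illustrative values):
//
//     c-order 3x4 A, strides {4, 1}: aMcont = false, aKcont = true
//         -> lda = strideAt(0) = 4   (rows contiguous, row-major leading dim = K)
//     f-order 3x4 A, strides {1, 3}: aMcont = true,  aKcont = false
//         -> lda = strideAt(1) = 3   (columns contiguous, col-major leading dim = M)
//
// i.e. the leading dimension handed to gemm is simply the stride of the
// non-unit axis, so padded or viewed matrices can be passed without copying.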
pC->strideAt(0) : pC->strideAt(1); + + if(typeFloat) { + BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, pA->bufferAsT(), lda, pB->bufferAsT(), ldb, (float) beta, pC->bufferAsT(), ldc); + } + else if(typeDouble) { + BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, pA->bufferAsT(), lda, pB->bufferAsT(), ldb, (double) beta, pC->bufferAsT(), ldc); + } + + if(pC != C) { + C->assign(pC); + delete pC; + } + if(pA != A) + delete pA; + if(pB != B) + delete pB; } - if(pA != A) - delete pA; - if(pB != B) - delete pB; return C; } @@ -243,6 +280,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con //////////////////////////////////////////////////////////////////////////// // MXN x N = M NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) { + if (X->dataType() != A->dataType()) throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), X->dataType()); @@ -254,56 +292,65 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* if(A->rankOf() != 2) throw std::runtime_error("MmulHelper::mmulMxV: rank of A array is not equal 2 !"); if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !"); + throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !"); - const auto M = A->sizeAt(0); + const auto M = A->sizeAt(0); const auto N = A->sizeAt(1); - + if(Y != nullptr && !shape::isCommonVector(Y->getShapeInfo(), yLenDim)) throw std::runtime_error("MmulHelper::mmulMxV: Y array must be vector !"); if(X->lengthOf() != N) throw std::runtime_error("MmulHelper::mmulMxV: X vector has wrong length !"); if(Y != nullptr && Y->lengthOf() != M) - throw std::runtime_error("MmulHelper::mmulMxV: Y array has wrong length !"); + throw std::runtime_error("MmulHelper::mmulMxV: Y array has wrong length !"); - if(Y == nullptr) + if(Y == nullptr) Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); - - NDArray *pA(const_cast(A)); - if(A->ews() != 1) - pA = pA->dup(); - - CBLAS_ORDER blasOrder; - int lda; - if (pA->ordering() == 'f') {blasOrder = CblasColMajor; lda = M; } - else {blasOrder = CblasRowMajor; lda = N; } - const int incx = X->stridesOf()[xLenDim]; const int incy = Y->stridesOf()[yLenDim]; - const auto aType = pA->dataType(); + const auto aType = A->dataType(); const auto xType = X->dataType(); const auto yType = Y->dataType(); const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); const bool hasGemv = BlasHelper::getInstance()->hasGEMV(aType); - - // choose appropriate cuda gemm api depending on data types - if(AXY && hasGemv && aType == DataType::DOUBLE) { - BlasHelper::getInstance()->dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->getBuffer(), lda, (double*)X->getBuffer(), incx, beta, (double*)Y->getBuffer(), incy); - } - else if(AXY && hasGemv && aType == DataType::FLOAT32) { - BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy); + + const bool typeDouble = hasGemv && AXY && aType == DataType::DOUBLE; + const bool typeFloat = hasGemv && AXY && aType == DataType::FLOAT32; + + if(!typeDouble && !typeFloat) { + BUILD_SINGLE_SELECTOR_THRICE(aType, 
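//////////////////////////////////////////////////////////////////////////////
// Sketch of the fallback just set up above (hypothetical view, illustrative
// only): when neither axis of an operand has unit stride -- e.g. a strided
// slice of a larger array -- gemm cannot address it, so the code copies it to
// a dense f-order temporary, runs BLAS on the copy, and frees it afterwards:
//
//     NDArray* pA = someStridedView;                     // no unit-stride axis
//     if (!aMcont && !aKcont) { pA = someStridedView->dup('f'); aMcont = true; }
//     /* ... sgemm/dgemm on pA ... */
//     if (pA != someStridedView) delete pA;
//
// For the output the same trick applies, except the dense copy pC is written
// by gemm and then copied back with C->assign(pC) before being deleted.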
usualGemv, (A, X, Y, incx, incy, 0, alpha, beta), NUMERIC_TYPES);
+        // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (A, X, Y, incx, incy, 0, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
     }
     else {
-        BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), NUMERIC_TYPES);
-        //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+
+        NDArray *pA(const_cast<NDArray*>(A));
+
+        bool aMcont = M == 1 || A->strideAt(0) == 1;
+        bool aNcont = N == 1 || A->strideAt(1) == 1;
+
+        if(!aMcont && !aNcont) {
+            pA = A->dup('f');
+            aMcont = true;
+        }
+        const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor;
+
+        const int lda = (aMcont && aNcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1);
+
+        // choose appropriate blas gemv api depending on data types
+        if(typeDouble) {
+            BlasHelper::getInstance()->dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->getBuffer(), lda, (double*)X->getBuffer(), incx, beta, (double*)Y->getBuffer(), incy);
+        }
+        else if(typeFloat) {
+            BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy);
+        }
+
+        if(pA != A)
+            delete pA;
     }

-    if(pA != A)
-        delete pA;
-
     return Y;
 }

@@ -330,22 +377,327 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
     if(Y->lengthOf() != length)
         throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !");

-    if(Z == nullptr)
+    if(Z == nullptr)
         Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext());
-
+
     const Nd4jLong incx = X->stridesOf()[xLenDim];
     const Nd4jLong incy = Y->stridesOf()[yLenDim];

     const auto xType = X->dataType();
     const auto yType = Y->dataType();
     const auto zType = Z->dataType();
-
+
     BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), NUMERIC_TYPES);
     //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);

     return Z;
 }

+//////////////////////////////////////////////////////////////////////////////
+// [bS,M,K] x [bS,K,N] = [bS,M,N]
+// [bS,M,K] x [K,N] = [bS,M,N]
+// [M,K] x [bS,K,N] = [bS,M,N]
+// bS could stand for several axes
+template <typename T1, typename T2, typename T3>
+static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
+                        const int* aBatchDims, const int* bBatchDims, const int* cBatchDims,
+                        const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis,
+                        const double alpha, const double beta) {
+
+    const T1* A = vA->bufferAsT<T1>();
+    const T2* B = vB->bufferAsT<T2>();
+          T3* C = vC->bufferAsT<T3>();
+
+    const T3 alphaZ = alpha;
+    const T3 betaZ  = beta;
+
+    const bool betaPersent = beta;
+
+    const Nd4jLong* aShapeInfo = vA->getShapeInfo();
+    const Nd4jLong* bShapeInfo = vB->getShapeInfo();
+    const Nd4jLong* cShapeInfo = vC->getShapeInfo();
+
+    const int aRank = vA->rankOf();
+    const int bRank = vB->rankOf();
+    const int cRank = vC->rankOf();
+
+    const Nd4jLong cLen = vC->lengthOf();
+
+    const int K = vA->sizeAt(aKaxis);
+
+    auto func = PRAGMA_THREADS_FOR {
+
+        std::vector<Nd4jLong> aCoords(aRank), bCoords(bRank), cCoords(cRank);
+
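// Note (illustrative): the coordinate vectors above are deliberately declared
// inside the PRAGMA_THREADS_FOR body -- every worker thread gets its own
// scratch, replacing the firstprivate(...) copies the old OpenMP pragmas used:
//
//     auto func = PRAGMA_THREADS_FOR {
//         std::vector<Nd4jLong> scratch(rank);   // private per worker
//         for (auto i = start; i < stop; ++i) { /* ... */ }
//     };
//     samediff::Threads::parallel_tad(func, 0, cLen);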
for (auto i = start; i < stop; ++i) { + + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords.data()); + + // calculate index of current batch + Nd4jLong batchInd; + if(cRank > 2) + batchInd = shape::coords2index(cShapeInfo, cCoords.data(), cRank - 2, cBatchDims); + + // evaluate A coordinates + if(aRank > 2) + shape::index2coords(batchInd, aShapeInfo, aCoords.data(), aRank - 2, aBatchDims); + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; + + // evaluate B coordinates + if(bRank > 2) + shape::index2coords(batchInd, bShapeInfo, bCoords.data(), bRank - 2, bBatchDims); + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; + + auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); + auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); + + T3 val = A[aOffset] * B[bOffset]; // first iteration + + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } + + auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); + + if(betaPersent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } + }; + + samediff::Threads::parallel_tad(func, 0, cLen); +} + +////////////////////////////////////////////////////////////////////////// +// [bS,M,K] x [bS,K,N] = [bS,M,N] +// [bS,M,K] x [K,N] = [bS,M,N] +// [M,K] x [bS,K,N] = [bS,M,N] +// bS could stand for several axes +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { + + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if(aRank > bRank && bRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if(bRank > aRank && aRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank ) { + for(int i = 0; i < aRank - 2; ++i) + if(A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + } + + if(A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = aRank > bRank ? 
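//////////////////////////////////////////////////////////////////////////////
// Example of the batch-index mapping used above (illustrative shapes):
//
//     A: [2,3,M,K], B: [K,N] -> C: [2,3,M,N]
//     cBatchDims = {0,1}                        // all C axes except trailing {M,N}
//     batchInd   = coords2index over C's batch axes, then
//     index2coords(batchInd, aShapeInfo, aCoords, aRank - 2, aBatchDims)
//     recovers the matching (b0,b1) position inside A.
//
// B has rank 2 here, so bBatchDims is never consulted and its coordinates are
// fully determined by {k, n} -- which is exactly how one operand gets
// broadcast across the batch dimensions of the other.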
A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if(C != nullptr ) { + if(!C->isSameShape(cExpectedShape)) + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + } + else { + C = new NDArray(outOrder, cExpectedShape, B->dataType()); + } + + const int cRank = C->rankOf(); + + const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); + + std::vector aBatchDims, bBatchDims, cBatchDims; + + if(aRank > 2) + aBatchDims = ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}); + if(bRank > 2) + bBatchDims = ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}); + if(cRank > 2) + cBatchDims = ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}); + + // BUILD_TRIPLE_SELECTOR(A->dataType(), B->dataType(), C->dataType(), batchedGemm, (A, B, C, aBatchDims.data(), bBatchDims.data(), cBatchDims.data(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (A, B, C, aBatchDims.data(), bBatchDims.data(), cBatchDims.data(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES); + + return C; +} + +/* +////////////////////////////////////////////////////////////////////////// +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { + + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if(aRank > bRank && bRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if(bRank > aRank && aRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank ) { + for(int i = 0; i < aRank - 2; ++i) + if(A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + } + + if(A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = aRank > bRank ? 
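//////////////////////////////////////////////////////////////////////////////
// For comparison (illustrative): the commented-out implementation kept below
// realized NxN matmul by slicing out every {M,N} sub-array and calling
// mmulMxM per batch entry, roughly
//
//     for (Nd4jLong i = 0; i < numOfSubArrs; ++i)
//         mmulMxM(&aSubArr, &bSubArr, &cSubArr, 1., 0., outOrder);
//
// whereas the new batchedGemm above fuses all batches into one parallel loop
// over C's elements, avoiding the per-sub-array NDArray construction.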
A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if(C != nullptr ) { + if(!C->isSameShape(cExpectedShape)) + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + } + else { + C = new NDArray(outOrder, cExpectedShape, B->dataType()); + } + + + // multiplication + const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); + const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude); + std::vector idxRanges(2 * C->rankOf()); + +// #pragma omp parallel for schedule(guided) firstprivate(idxRanges) + for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { + + ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data()); + NDArray cSubArr = (*C)(idxRanges); + + if(aRank > bRank) { + NDArray aSubArr = (*A)(idxRanges); + mmulMxM(&aSubArr, B, &cSubArr, 1., 0., outOrder); + } + else if(bRank > aRank) { + NDArray bSubArr = (*B)(idxRanges); + mmulMxM(A, &bSubArr, &cSubArr, 1., 0, outOrder); + } + else { + NDArray aSubArr = (*A)(idxRanges); + NDArray bSubArr = (*B)(idxRanges); + mmulMxM(&aSubArr, &bSubArr, &cSubArr, 1., 0., outOrder); + } + } + + return C; +} + +////////////////////////////////////////////////////////////////////////////// +// MXK x KxN = MxN +template +static void usualGemm(const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { + + T1* A = reinterpret_cast(const_cast(vA)); + T2* B = reinterpret_cast(const_cast(vB)); + T3* C = reinterpret_cast(vC); + T3 alphaZ(alpha), betaZ(beta); + + const bool flagC = cOrder == 'f'; + const bool flagA = (flagC && transA) || (!flagC && !transA); + const bool flagB = (flagC && transB) || (!flagC && !transB); + + // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) + // for(uint row = 0; row < M; ++row) { + + // T3* c = flagC ? (C + row) : (C + row * ldc); + + // for(uint col = 0; col < N; ++col) + // c[flagC ? col * ldc : col] = 0; + + // for(uint i = 0; i < K; ++i) { + + // T3* b = flagB ? (B + i * ldb) : (B + i); + // T3* a = flagA ? (A + row * lda + i) : (A + row + i * lda); + + // if(flagC) { + // for(uint col = 0; col < N; ++col) { + // if(betaZ) + // c[col * ldc] += a * b[flagB ? col : col * ldb] + betaZ * c[col * ldc]; + // else + // c[col * ldc] += a * b[flagB ? col : col * ldb]; + // } + // } + // else { + // for(uint col = 0; col < N; ++col) { + // if(betaZ) + // c[col] += a * b[flagB ? col : col * ldb] + betaZ * c[col]; + // else + // c[col] += a * b[flagB ? col : col * ldb]; + // } + // } + // } + // } + + auto func = PRAGMA_THREADS_FOR_2D { ; + for (auto row = start_x; row < stop_x; row += inc_x) { + for (auto col = start_y; col < stop_y; col += inc_y) { + T3 *c = flagC ? (C + row + col * ldc) : (C + row * ldc + col); + T3 val = 0; + + for (uint i = 0; i < K; ++i) { + T3 a = flagA ? *(A + row * lda + i) : *(A + row + i * lda); + T3 b = flagB ? 
*(B + col + i * ldb) : *(B + col * ldb + i); + val += alphaZ * a * b; + } + + if (betaZ) + *c = val + betaZ * *c; + else + *c = val; + } + } + }; + + samediff::Threads::parallel_tad(func, 0, M, 1, 0, N, 1); +} + +////////////////////////////////////////////////////////////////////////////// +// MXN x N = M +template +static void usualGemv(const char aOrder, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vX, const int incx, const double beta, void* vY, const int incy) { + + T1* A = reinterpret_cast(const_cast(vA)); + T2* X = reinterpret_cast(const_cast(vX)); + T3* Y = reinterpret_cast(vY); + T3 alphaZ(alpha), betaZ(beta); + + const bool flagA = aOrder == 'f'; + + auto func = PRAGMA_THREADS_FOR { + for (auto row = start; row < stop; row += increment) { + + T3 *y = Y + row * incy; + T3 val = 0; + + for (int i = 0; i < N; ++i) { + T3 a = flagA ? *(A + row + i * lda) : *(A + row * lda + i); + T3 x = *(X + i * incx); + val += alphaZ * a * x; + } + + if (betaZ) + *y = val + betaZ * *y; + else + *y = val; + } + }; + + samediff::Threads::parallel_tad(func, 0, M); +} +*/ + //BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); diff --git a/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp b/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp index c4c2fa995..dbf080ac9 100644 --- a/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp +++ b/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp @@ -20,6 +20,7 @@ #include #include +#include using namespace simdOps; @@ -47,36 +48,39 @@ void TrueBroadcastHelper::exec(const NDArray& xArr, const NDArray& yArr std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords)) - for (Nd4jLong i = 0; i < zLen; ++i) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; ++i) { - shape::index2coords(i, zShapeInfo, zCoords.data()); + shape::index2coords(i, zShapeInfo, zCoords.data()); - for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { + for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - if(ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { - xCoords[ix--] = zCoords[iz]; - } else { - xCoords[ix--] = 0; + if (ix >= 0) { + if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { + xCoords[ix--] = zCoords[iz]; + } else { + xCoords[ix--] = 0; + } + } + + if (iy >= 0) { + if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { + yCoords[iy--] = zCoords[iz]; + } else { + yCoords[iy--] = 0; + } } } - if(iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { - yCoords[iy--] = zCoords[iz]; - } else { - yCoords[iy--] = 0; - } - } + const auto xOffset = 
shape::getOffset(xShapeInfo, xCoords.data()); + const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); + const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); } + }; - const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } + samediff::Threads::parallel_for(func, 0, zLen); } template @@ -103,38 +107,40 @@ void TrueBroadcastBoolHelper::exec(const NDArray& xArr, const NDArray& yAr const Nd4jLong zLen = zArr.lengthOf(); - std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); + auto func = PRAGMA_THREADS_FOR { + std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); + for (auto i = start; i < stop; ++i) { - PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords)) - for (Nd4jLong i = 0; i < zLen; ++i) { + shape::index2coords(i, zShapeInfo, zCoords.data()); - shape::index2coords(i, zShapeInfo, zCoords.data()); + for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { + if (ix >= 0) { + if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { + xCoords[ix--] = zCoords[iz]; + } else { + xCoords[ix--] = 0; + } + } - if(ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { - xCoords[ix--] = zCoords[iz]; - } else { - xCoords[ix--] = 0; + if (iy >= 0) { + if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { + yCoords[iy--] = zCoords[iz]; + } else { + yCoords[iy--] = 0; + } } } - if(iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { - yCoords[iy--] = zCoords[iz]; - } else { - yCoords[iy--] = 0; - } - } + const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); + const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); + const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); } + }; - const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } + samediff::Threads::parallel_for(func, 0, zLen); } template @@ -163,36 +169,39 @@ void TrueBroadcastIntHelper::exec(const NDArray& xArr, const NDArray& yArr, N std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords)) - for (Nd4jLong i = 0; i < zLen; ++i) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; ++i) { - shape::index2coords(i, zShapeInfo, zCoords.data()); + shape::index2coords(i, zShapeInfo, zCoords.data()); - for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { + for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - if(ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { - xCoords[ix--] = zCoords[iz]; - } else { - xCoords[ix--] = 0; + if (ix >= 0) { + if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { + xCoords[ix--] = zCoords[iz]; + } else { + xCoords[ix--] = 0; + } + } + + if (iy >= 0) { + if (yShapeInfo[iy + 1] == 
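//////////////////////////////////////////////////////////////////////////////
// Broadcast walk, worked example (illustrative): x:[3,1], y:[1,4], z:[3,4].
// For the output coordinate (2,3) the loop above compares extents axis by
// axis (xShapeInfo[ix + 1] vs zShapeInfo[iz + 1] -- shapeInfo stores rank at
// slot 0 and extents from slot 1) and clamps size-1 axes to 0:
//
//     xCoords = {2, 0}     // axis 1 of x has extent 1 != 4
//     yCoords = {0, 3}     // axis 0 of y has extent 1 != 3
//
// so z(2,3) = op(x(2,0), y(0,3)), i.e. plain NumPy-style broadcasting.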
zShapeInfo[iz + 1]) { + yCoords[iy--] = zCoords[iz]; + } else { + yCoords[iy--] = 0; + } } } - if(iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { - yCoords[iy--] = zCoords[iz]; - } else { - yCoords[iy--] = 0; - } - } + const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); + const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); + const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); } + }; - const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } + samediff::Threads::parallel_for(func, 0, zLen); } template diff --git a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu index 8f67f0004..4b7904bca 100644 --- a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu +++ b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu @@ -86,7 +86,7 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI const auto zOffset = shape::getOffset(zShapeInfo, zCoords); const auto yOffset = shape::getOffset(yShapeInfo, yCoords); - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); } } @@ -172,7 +172,7 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh const auto zOffset = shape::getOffset(zShapeInfo, zCoords); const auto yOffset = shape::getOffset(yShapeInfo, yCoords); - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); } } diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index 32394f705..d191c7803 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,52 +23,648 @@ #include #include "../MmulHelper.h" #include +#include +#include +#include namespace nd4j { - ////////////////////////////////////////////////////////////////////////////// -// MXK x KxN = MxN -// C array must be in f order +// MXK x KxN = MxN -> actual sequence of axes doesn't matter template -static __global__ void usualCudaGemm(const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { +static __global__ void usualCudaGemm(const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, + const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { - T1* A = reinterpret_cast(const_cast(vA)); - T2* B = reinterpret_cast(const_cast(vB)); - T3* C = reinterpret_cast(vC); + const T1* A = reinterpret_cast(vA); + const T2* B = reinterpret_cast(vB); + T3* C = reinterpret_cast< T3*>(vC); + __shared__ int K; + __shared__ bool betaPresent; + __shared__ Nd4jLong cLen, totalThreads, *coords; __shared__ T3 alphaZ, betaZ; - __shared__ Nd4jLong strideArow, strideAcol, strideBrow, strideBcol; - const int row = blockIdx.y * blockDim.y + threadIdx.y; - const int col = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.x == 0) { - if(row == 0 && col == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + cLen = shape::length(cShapeInfo); + + K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; + + betaPresent = beta; + + totalThreads = gridDim.x * blockDim.x; alphaZ = alpha; betaZ = beta; - - if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; strideAcol = lda; } - if(transB) { strideBrow = ldb; strideBcol = 1; } else { strideBrow = 1; strideBcol = ldb; } } - __syncthreads(); - T3 val = 0; - if (row < M && col < N) - for (int i = 0; i < K; i++) - val = val + A[row * strideArow + i * strideAcol] * B[i * strideBrow + col * strideBcol]; + auto aCoords = coords + threadIdx.x * 6; // 6 = (aRank + bRank + cRank) + auto bCoords = aCoords + 2; + auto cCoords = bCoords + 2; - C[row + col * ldc] = alphaZ * val + betaZ * C[row + col * ldc]; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords); + + // evaluate A coordinates + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; + + // evaluate B coordinates + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; + + auto aOffset = shape::getOffset(aShapeInfo, aCoords); + auto bOffset = shape::getOffset(bShapeInfo, bCoords); + + T3 val = A[aOffset] * B[bOffset]; // first iteration + + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } + + auto cOffset = shape::getOffset(cShapeInfo, cCoords); + + if(betaPresent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } } //////////////////////////////////////////////////////////////////////// template -__host__ static void usualGemm(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t 
*stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { +__host__ static void usualGemm(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, const double alpha, const double beta) { - usualCudaGemm<<>>(transA, transB, M, N, K, alpha, vA, lda, vB, ldb, beta, vC, ldc); + usualCudaGemm<<>>(vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta); } +//////////////////////////////////////////////////////////////////////// +// MXN x N = M -> actual sequence of {M,N} axes doesn't matter +template +static __global__ void usualCudaGemv(const void* vA, const Nd4jLong* aShapeInfo, const void* vX, const Nd4jLong* xShapeInfo, void* vY, const Nd4jLong* yShapeInfo, + const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { + + const T1* A = reinterpret_cast(vA); + const T2* X = reinterpret_cast(vX); + T3* Y = reinterpret_cast< T3*>(vY); + + __shared__ int M, N; + __shared__ bool betaPresent; + __shared__ Nd4jLong cLen, totalThreads, aNstride, aMstride; + __shared__ T3 alphaZ, betaZ; + + if (threadIdx.x == 0) { + + N = shape::length(xShapeInfo); + M = shape::length(yShapeInfo); + + aMstride = shape::stride(aShapeInfo)[aMaxis]; + aNstride = shape::stride(aShapeInfo)[aMaxis == 0 ? 1 : 0]; + + totalThreads = gridDim.x * blockDim.x; + + betaPresent = beta; + + alphaZ = alpha; + betaZ = beta; + } + __syncthreads(); + + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < M; i += totalThreads) { + + // evaluate offsets + auto aOffset = i * aMstride; + auto xOffset = 0; + + T3 val = A[aOffset] * X[xOffset]; // first iteration + + for (uint j = 1; j < N; ++j) { // rest iterations + aOffset += aNstride; + xOffset += incx; + val = val + A[aOffset] * X[xOffset]; + } + + auto yOffset = i * incy; + + if(betaPresent) + Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; + else + Y[yOffset] = alphaZ * val; + } +} + +//////////////////////////////////////////////////////////////////////// +template +__host__ static void usualGemv(const int blocksPerGrid, const int threadsPerBlock, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vX, const Nd4jLong* xShapeInfo, void* vY, const Nd4jLong* yShapeInfo, const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { + + usualCudaGemv<<>>(vA, aShapeInfo, vX, xShapeInfo, vY, yShapeInfo, incx, incy, aMaxis, alpha, beta); +} + + +////////////////////////////////////////////////////////////////////////////// +template +static __global__ void usualCudaDot(const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { + + T1* X = reinterpret_cast(const_cast(vX)); + T2* Y = reinterpret_cast(const_cast(vY)); + T3* Z = reinterpret_cast(vZ); + + extern __shared__ unsigned char shmem[]; + auto pairwiseMul = reinterpret_cast(shmem); + + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if(tid < length) + pairwiseMul[tid] = X[tid * incx] * Y[tid * incy]; + + __syncthreads(); + + if(tid == 0) { + T3 sum 
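//////////////////////////////////////////////////////////////////////////////
// Note on usualCudaDot above (illustrative): every thread writes one pairwise
// product into dynamic shared memory and thread 0 then sums them serially, so
// the kernel is only correct when a single block covers the whole vector:
//
//     sharedMem >= length * sizeof(T3);
//     usualCudaDot<<<1, threads, sharedMem, stream>>>(...);  // threads >= length
//
// Products written by any other block would live in that block's own shared
// memory and never be visible to the summing thread.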
= 0; + for(Nd4jLong i = 0; i < length; ++i) + sum = sum + pairwiseMul[i]; + + if(beta) + *Z = (T3)alpha * sum + (T3)beta * *Z; + else + *Z = (T3)alpha * sum; + } +} + +//////////////////////////////////////////////////////////////////////// +template +__host__ static void usualDot(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { + + usualCudaDot<<>>(length, alpha, vX, incx, vY, incy, beta, vZ); +} + +////////////////////////////////////////////////////////////////////////////// +// MXK x KxN = MxN +NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, double alpha, double beta, const char outOrder) { + + if(A->rankOf() != 2) + throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of A array is not equal 2 !"); + if(B->rankOf() != 2) + throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of B array is not equal 2 !"); + + const auto M = A->sizeAt(0); + const auto K = A->sizeAt(1); + const auto N = B->sizeAt(1); + + if(C != nullptr && C->rankOf() != 2) + throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of C array is not equal 2 !"); + if(B->sizeAt(0) != K) + throw std::runtime_error("MmulHelper::mmulMxM cuda: B array has wrong number of rows !"); + if(C != nullptr && C->sizeAt(0) != M) + throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of rows !"); + if(C != nullptr && C->sizeAt(1) != N) + throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of columns !"); + + if(C == nullptr) + C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + + const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first(); + + const auto aType = A->dataType(); + const auto bType = B->dataType(); + const auto cType = C->dataType(); + + const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); + + const bool typeDouble = ABC && aType == DataType::DOUBLE; + const bool typeFloat = ABC && aType == DataType::FLOAT32; + const bool typeHalf = ABC && aType == DataType::HALF && major >= 6; + const bool typeIntFloat = AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6; + const bool typeHalfFloat = AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6; + + auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); + auto stream = A->getContext()->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + + if(!typeDouble && !typeFloat && !typeHalf && !typeIntFloat && !typeHalfFloat) { + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * 6 + 128; // 6 = aRank + bRank + cRank + + NDArray::prepareSpecialUse({C}, {A, B}); + // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, 
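//////////////////////////////////////////////////////////////////////////////
// For reference, the dispatch encoded by the type flags above and used in the
// calls below (cc = device compute capability major):
//
//     double,double -> double               : cublasDgemm
//     float,float   -> float                : cublasSgemm
//     half,half     -> half     (cc >= 6)   : cublasHgemm
//     int8,int8     -> float32  (cc >= 6)   : cublasSgemmEx, CUDA_R_8I inputs
//     half,half     -> float32  (cc >= 6)   : cublasSgemmEx, CUDA_R_16F inputs
//     anything else                         : usualCudaGemm fallback kernel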
sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES) + NDArray::registerSpecialUse({C}, {A, B}); + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); + } + else { + + std::vector toDelete; + + NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aKcont = K == 1 || A->strideAt(1) == 1; + bool bKcont = K == 1 || B->strideAt(0) == 1; + bool bNcont = N == 1 || B->strideAt(1) == 1; + bool cMcont = M == 1 || C->strideAt(0) == 1; + bool cNcont = N == 1 || C->strideAt(1) == 1; + + if(!aMcont && !aKcont) { + pA = A->dup('f'); + toDelete.push_back(pA); + aMcont = true; + } + if(!bKcont && !bNcont) { + pB = B->dup('f'); + toDelete.push_back(pB); + bKcont = true; + } + if(!cMcont) { + pC = C->dup('f'); + toDelete.push_back(pC); + cMcont = true; + } + + const bool transA = !aMcont; + const bool transB = !bKcont; + + const int lda = (aMcont && aKcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); + const int ldb = (bKcont && bNcont) ? K : transB ? pB->strideAt(0) : pB->strideAt(1); + const int ldc = (cMcont && cNcont) ? M : pC->strideAt(1); + + const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t transBblas = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + + NDArray::prepareSpecialUse({pC}, {pA, pB}); + + // choose appropriate cuda gemm api depending on data types + if(typeDouble) { + status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)pB->getSpecialBuffer(), ldb, &beta, (double*)pC->getSpecialBuffer(), ldc); + } + else if(typeFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc); + } + else if(typeHalf) { + float16 alphaH(alpha), betaH(beta); + status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc); + } + else if(typeIntFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); + } + else if(typeHalfFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); + } + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + + NDArray::registerSpecialUse({pC}, {pA, pB}); + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); + + if(C != pC) + C->assign(pC); + + for(int i = toDelete.size() - 1; i >= 0; --i) + delete toDelete[i]; + } + + return C; +} + +//////////////////////////////////////////////////////////////////////////// +// MXN x N = M +NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, 
nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) { + + int xLenDim, yLenDim(0); + + if(A->rankOf() != 2) + throw std::runtime_error("MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !"); + if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) + throw std::runtime_error("MmulHelper::mmulMxV cuda: X array must be vector !"); + + const auto M = A->sizeAt(0); + const auto N = A->sizeAt(1); + + if(Y != nullptr && !shape::isCommonVector(Y->getShapeInfo(), yLenDim)) + throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array must be vector !"); + if(X->lengthOf() != N) + throw std::runtime_error("MmulHelper::mmulMxV cuda: X vector has wrong length !"); + if(Y != nullptr && Y->lengthOf() != M) + throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array has wrong length !"); + + if(Y == nullptr) + Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); + + const int incx = X->strideAt(xLenDim); + const int incy = Y->strideAt(yLenDim); + + const auto aType = A->dataType(); + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + + const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); + + const bool typeDouble = AXY && aType == DataType::DOUBLE; + const bool typeFloat = AXY && aType == DataType::FLOAT32; + + auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); + auto stream = A->getContext()->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); + + if(!typeDouble && !typeFloat) { + + const int threadsPerBlock = MAX_NUM_THREADS; + const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({Y}, {A, X}); + // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES) + NDArray::registerSpecialUse({Y}, {A, X}); + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult); + + } + else { + + NDArray *pA(const_cast(A)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aNcont = N == 1 || A->strideAt(1) == 1; + + if(!aMcont && !aNcont) { + pA = A->dup('f'); + aMcont = true; + } + + const bool transA = !aMcont; + + const int lda = (aMcont && aNcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); + + const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + + NDArray::prepareSpecialUse({Y}, {pA, X}); + + // choose appropriate cuda gemm api depending on data types + if(typeDouble) { + status = cublasDgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)X->getSpecialBuffer(), incx, &beta, (double*)Y->getSpecialBuffer(), incy); + } + else if(typeFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemv(*handle, transAblas, transA ? N : M, transA ? 
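//////////////////////////////////////////////////////////////////////////////
// Worked example for the dimension swap below (illustrative): cuBLAS gemv is
// strictly column-major, so a c-order 3x4 A (M = 3, N = 4, strides {4,1}) is
// handed over as its col-major transpose:
//
//     transA = true, lda = 4
//     cublasSgemv(handle, CUBLAS_OP_T, /*m=*/4, /*n=*/3, ...);
//
// i.e. the same memory is described as a col-major 4x3 matrix and the op flag
// requests its transpose, which is the original row-major 3x4 view.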
M : N, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)X->getSpecialBuffer(), incx, &betaF, (float*)Y->getSpecialBuffer(), incy); + } + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult); + + NDArray::registerSpecialUse({Y}, {pA, X}); + + if(pA != A) + delete pA; + } + + return Y; +} + +//////////////////////////////////////////////////////////////////////////// +// (X * Y) = Z[0] +NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, const double alpha, const double beta) { + + int xLenDim(0), yLenDim(0); + + if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) + throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); + if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim)) + throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); + if(Z != nullptr && !Z->isScalar()) + throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); + + const auto length = X->lengthOf(); + + if(Y->lengthOf() != length) + throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !"); + + if(Z == nullptr) + Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); + + const Nd4jLong incx = X->strideAt(xLenDim); + const Nd4jLong incy = Y->strideAt(yLenDim); + + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + const auto zType = Z->dataType(); + + if(!X->isActualOnDeviceSide()) X->syncToDevice(); + if(!Y->isActualOnDeviceSide()) Y->syncToDevice(); + if(!Z->isActualOnDeviceSide()) Z->syncToDevice(); + + cudaStream_t* stream = X->getContext()->getCudaStream(); + + dim3 threadsPerBlock(512); + dim3 blocksPerGrid(1); + if (length > 512) + threadsPerBlock.x = math::nd4j_ceil(static_cast(length) / 512); + + NDArray::prepareSpecialUse({Z}, {X, Y}); + + //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES) + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); + + NDArray::registerSpecialUse({Z}, {X, Y}); + + return Z; +} + +////////////////////////////////////////////////////////////////////////////// +// [bS,M,K] x [bS,K,N] = [bS,M,N] +// [bS,M,K] x [K,N] = [bS,M,N] +// [M,K] x [bS,K,N] = [bS,M,N] +// bS could stand for several axes +template +static __global__ void batchedCudaGemm(const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, + const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, + const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { + + const T1* A = reinterpret_cast(vA); + const T2* B = reinterpret_cast(vB); + T3* C = reinterpret_cast< T3*>(vC); + + __shared__ bool betaPresent; + __shared__ int aRank, bRank, cRank, K; + __shared__ Nd4jLong cLen, 
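//////////////////////////////////////////////////////////////////////////////
// Layout of the dynamic shared memory consumed by this kernel (illustrative
// numbers): each thread owns one Nd4jLong coordinate slot per axis of A, B
// and C, so the host side sizes the allocation as
//
//     sharedMem = threadsPerBlock * sizeof(Nd4jLong) * (aRank + bRank + cRank) + 128;
//     // e.g. ranks 3/2/3 with 64 threads: 64 * 8 * 8 + 128 = 4224 bytes
//
// with the trailing 128 bytes presumably left as alignment slack.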
totalThreads, *coords; + __shared__ T3 alphaZ, betaZ; + + if (threadIdx.x == 0) { + + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + cLen = shape::length(cShapeInfo); + + K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; + + totalThreads = gridDim.x * blockDim.x; + aRank = shape::rank(aShapeInfo); + bRank = shape::rank(bShapeInfo); + cRank = shape::rank(cShapeInfo); + + betaPresent = beta; + + alphaZ = alpha; + betaZ = beta; + } + __syncthreads(); + + auto aCoords = coords + threadIdx.x * (aRank + bRank + cRank); + auto bCoords = aCoords + aRank; + auto cCoords = bCoords + bRank; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords); + + // calculate index of current batch + Nd4jLong batchInd; + if(cBatchDims != nullptr) + batchInd = shape::coords2index(cShapeInfo, cCoords, cRank - 2, cBatchDims); + + // evaluate A coordinates + if(aBatchDims != nullptr) + shape::index2coords(batchInd, aShapeInfo, aCoords, aRank - 2, aBatchDims); + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; + + // evaluate B coordinates + if(bBatchDims != nullptr) + shape::index2coords(batchInd, bShapeInfo, bCoords, bRank - 2, bBatchDims); + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; + + auto aOffset = shape::getOffset(aShapeInfo, aCoords); + auto bOffset = shape::getOffset(bShapeInfo, bCoords); + + T3 val = A[aOffset] * B[bOffset]; // first iteration + + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } + + auto cOffset = shape::getOffset(cShapeInfo, cCoords); + + if(betaPresent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } +} + +//////////////////////////////////////////////////////////////////////// +template +__host__ static void batchedGemm(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, const double alpha, const double beta) { + + batchedCudaGemm<<>>(vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta); +} + +/////////////////////////////////////////////////////////////////// +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { + + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if(aRank > bRank && bRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if(bRank > aRank && aRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank ) { + for(int i = 0; i < aRank - 2; ++i) + if(A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + } + + if(A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not 
suitable for matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if(C != nullptr ) { + if(!C->isSameShape(cExpectedShape)) + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + } + else + C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + + const int cRank = C->rankOf(); + + const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); + + const int threadsPerBlock = MAX_NUM_THREADS / 8; + const int blocksPerGrid = (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * (aRank + bRank + cRank) + 128; + + PointersManager manager(A->getContext(), "MmulHelper::mmulNxN"); + + const int *aBatchDims(nullptr), *bBatchDims(nullptr), *cBatchDims(nullptr); + + if(aRank > 2) + aBatchDims = reinterpret_cast(manager.replicatePointer(ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}).data(), (aRank - 2) * sizeof(int))); + if(bRank > 2) + bBatchDims = reinterpret_cast(manager.replicatePointer(ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}).data(), (bRank - 2) * sizeof(int))); + if(cRank > 2) + cBatchDims = reinterpret_cast(manager.replicatePointer(ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), (cRank - 2) * sizeof(int))); + + NDArray::prepareSpecialUse({C}, {A, B}); + // BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES) + NDArray::registerSpecialUse({C}, {A, B}); + + manager.synchronize(); + + return C; +} + + +/* ////////////////////////////////////////////////////////////////////////////// // MXN x N = M template @@ -106,309 +703,331 @@ __host__ static void usualGemv(const dim3 &blocksPerGrid, const dim3 &threadsPer usualCudaGemv<<>>(transA, M, N, alpha, vA, lda, vX, incx, beta, vY, incy); } - +*/ +/* ////////////////////////////////////////////////////////////////////////////// +MXK x KxN = MxN +C array must be in f order template -static __global__ void usualCudaDot(const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { +static __global__ void usualCudaGemm(const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { - T1* X = reinterpret_cast(const_cast(vX)); - T2* Y = reinterpret_cast(const_cast(vY)); - T3* Z = reinterpret_cast(vZ); + 
T1* A = reinterpret_cast(const_cast(vA)); + T2* B = reinterpret_cast(const_cast(vB)); + T3* C = reinterpret_cast(vC); - extern __shared__ char shmem[]; - auto pairwiseMul = reinterpret_cast(shmem); + __shared__ T3 alphaZ, betaZ; + __shared__ Nd4jLong strideArow, strideAcol, strideBrow, strideBcol; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - if(tid < length) - pairwiseMul[tid] = X[tid * incx] * Y[tid * incy]; + const int row = blockIdx.y * blockDim.y + threadIdx.y; + const int col = blockIdx.x * blockDim.x + threadIdx.x; + + if(row == 0 && col == 0) { + + alphaZ = alpha; + betaZ = beta; + + if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; strideAcol = lda; } + if(transB) { strideBrow = ldb; strideBcol = 1; } else { strideBrow = 1; strideBcol = ldb; } + } __syncthreads(); - if(tid == 0) { - T3 sum = 0; - for(Nd4jLong i = 0; i < length; ++i) - sum = sum + pairwiseMul[i]; - *Z = (T3)alpha * sum + (T3)beta * *Z; - } -} + T3 val = 0; + if (row < M && col < N) + for (int i = 0; i < K; i++) + val = val + A[row * strideArow + i * strideAcol] * B[i * strideBrow + col * strideBcol]; -//////////////////////////////////////////////////////////////////////// -template -__host__ static void usualDot(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { - - usualCudaDot<<>>(length, alpha, vX, incx, vY, incy, beta, vZ); + C[row + col * ldc] = alphaZ * val + betaZ * C[row + col * ldc]; } ////////////////////////////////////////////////////////////////////////////// -// MXK x KxN = MxN -NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, double alpha, double beta, const char outOrder) { +template +__host__ static void usualGemm(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of A array is not equal 2 !"); - if(B->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of B array is not equal 2 !"); + usualCudaGemm<<>>(transA, transB, M, N, K, alpha, vA, lda, vB, ldb, beta, vC, ldc); +} +*/ +////////////////////////////////////////////////////////////////////////// +/* +NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { - auto M = A->sizeAt(0); - auto K = A->sizeAt(1); - auto N = B->sizeAt(1); + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); - if(C != nullptr && C->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of C array is not equal 2 !"); - if(B->sizeAt(0) != K) - throw std::runtime_error("MmulHelper::mmulMxM cuda: B array has wrong number of rows !"); - if(C != nullptr && C->sizeAt(0) != M) - throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of rows !"); - if(C != nullptr && C->sizeAt(1) != N) - throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of columns !"); + // input ranks validation + if(aRank > bRank && bRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if(bRank > aRank && aRank != 2) + throw 
std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank ) { + for(int i = 0; i < aRank - 2; ++i) + if(A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + } - if(C == nullptr) - C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + if(A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if(C != nullptr ) { + if(!C->isSameShape(cExpectedShape)) + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + } + else { + C = new NDArray(outOrder, cExpectedShape, B->dataType()); + } + + + // multiplication + const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); + const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude); + std::vector idxRanges(2 * C->rankOf()); + +// #pragma omp parallel for schedule(guided) firstprivate(idxRanges) + for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { + + ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data()); + NDArray cSubArr = (*C)(idxRanges); + + if(aRank > bRank) { + NDArray aSubArr = (*A)(idxRanges); + mmulMxM(&aSubArr, B, &cSubArr, 1., 0., outOrder); + } + else if(bRank > aRank) { + NDArray bSubArr = (*B)(idxRanges); + mmulMxM(A, &bSubArr, &cSubArr, 1., 0, outOrder); + } + else { + NDArray aSubArr = (*A)(idxRanges); + NDArray bSubArr = (*B)(idxRanges); + mmulMxM(&aSubArr, &bSubArr, &cSubArr, 1., 0., outOrder); + } + } + + return C; +} +*/ + +////////////////////////////////////////////////////////////////////////// +// [bS,M,K] x [bS,K,N] = [bS,M,N] +// [bS,M,K] x [K,N] = [bS,M,N] +// [M,K] x [bS,K,N] = [bS,M,N] +// bS could stand for several axes +/* +NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { + + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if(aRank > bRank && bRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if(bRank > aRank && aRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank ) { + for(int i = 0; i < aRank - 2; ++i) + if(A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + } + + if(A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = aRank > bRank ? 
A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if(C != nullptr ) { + if(!C->isSameShape(cExpectedShape)) + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + } + else + C = new NDArray(outOrder, cExpectedShape, B->dataType()); + + const int cRank = C->rankOf(); + + const auto M = A->sizeAt(-2); + const auto K = A->sizeAt(-1); + const auto N = B->sizeAt(-1); NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); std::vector toDelete; - if(A->ews() != 1) { - pA = pA->dup('f'); + bool aMcont = M == 1 || A->strideAt(-2) == 1; + bool aKcont = K == 1 || A->strideAt(-1) == 1; + bool bKcont = K == 1 || B->strideAt(-2) == 1; + bool bNcont = N == 1 || B->strideAt(-1) == 1; + bool cMcont = M == 1 || C->strideAt(-2) == 1; + bool cNcont = N == 1 || C->strideAt(-1) == 1; + + if(!aMcont && !aKcont) { + pA = A->dup('c'); toDelete.push_back(pA); + aKcont = true; } - if(B->ews() != 1) { - pB = pB->dup('f'); + if(!bKcont && !bNcont) { + pB = B->dup('c'); toDelete.push_back(pB); + bNcont = true; } - if(C->ews() != 1) { - pC = pC->dup('f'); + std::vector permut(cRank); + if(!cMcont) { + std::iota(permut.begin(), permut.end(), 0); + permut[cRank - 2] = cRank - 1; + permut[cRank - 1] = cRank - 2; // swap two last dimensions [..., M,N] -> [..., N,M] + auto Cpermut = C->permute(permut); + pC = new NDArray('c', Cpermut.getShapeAsVector(), Cpermut.dataType(), A->getContext()); + pC->assign(Cpermut); toDelete.push_back(pC); + cMcont = true; } - if(pC->ordering() != 'f') { - auto temp = pA; - pA = new NDArray(pB ->permute({1,0})); - pB = new NDArray(temp->permute({1,0})); - pC = new NDArray(pC ->permute({1,0})); - toDelete.push_back(pA); - toDelete.push_back(pB); - toDelete.push_back(pC); - M = pA->sizeAt(0); - K = pA->sizeAt(1); - N = pB->sizeAt(1); - } - - const auto aOrder = pA->ordering(); - const auto bOrder = pB->ordering(); - - const bool transA = aOrder != 'f'; - const bool transB = bOrder != 'f'; - - const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - const cublasOperation_t transBblas = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - - const int lda = aOrder == 'f' ? M : K; - const int ldb = bOrder == 'f' ? K : N; - const int ldc = M; // cOrder == 'f' ? 
M : N; const auto aType = pA->dataType(); const auto bType = pB->dataType(); const auto cType = pC->dataType(); - auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); - auto stream = A->getContext()->getCudaStream(); - - auto status = cublasSetStream_v2(*handle, *stream); - if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); - const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); - const int deviceId = AffinityManager::currentDeviceId(); - const int major = Environment::getInstance()->capabilities()[deviceId].first(); + bool badTypes = false; + cudaDataType_t cudaType, cudaAType, cudaBType, cudaCType; + + if(ABC && aType == DataType::HALF) { + cudaType = cudaAType = cudaBType = cudaCType = CUDA_R_16F; + } + else if(ABC && aType == DataType::FLOAT32) { + cudaType = cudaAType = cudaBType = cudaCType = CUDA_R_32F; + } + else if(ABC && aType == DataType::DOUBLE) { + cudaType = cudaAType = cudaBType = cudaCType = CUDA_R_64F; + } + else if(AB && cType == DataType::FLOAT32 && aType == DataType::INT8) { + cudaType = cudaCType = CUDA_R_32F; + cudaAType = cudaBType = CUDA_R_8I; + } + else if(AB && cType == DataType::FLOAT32 && aType == DataType::HALF) { + cudaType = cudaCType = CUDA_R_32F; + cudaAType = cudaBType = CUDA_R_16F; + } + else + badTypes = true; + + const int bS = pC->lengthOf() / (M*N); + + const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(cRank, {-2, -1}); NDArray::prepareSpecialUse({pC}, {pA, pB}); - // choose appropriate cuda gemm api depending on data types - if(ABC && aType == DataType::DOUBLE) { - status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)pB->getSpecialBuffer(), ldb, &beta, (double*)pC->getSpecialBuffer(), ldc); - } - else if(ABC && aType == DataType::FLOAT32) { - float alphaF(alpha), betaF(beta); - status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc); - } - else if(ABC && aType == DataType::HALF && major >= 6) { - float16 alphaH(alpha), betaH(beta); - status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc); - } - else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6) { - float alphaF(alpha), betaF(beta); - status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); - } - else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6) { - float alphaF(alpha), betaF(beta); - status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); - } - else { - dim3 threadsPerBlock(N, M); - dim3 blocksPerGrid(1, 1); - if (M*N > 512){ - threadsPerBlock.x = threadsPerBlock.y = 512; - blocksPerGrid.x = math::nd4j_ceil(static_cast(N) / threadsPerBlock.x); // cols - blocksPerGrid.y = math::nd4j_ceil(static_cast(M) / threadsPerBlock.y); // rows + if(!badTypes) { + + std::vector subArrOffsets(bS); + std::vector subArrShapeInfo(shape::shapeInfoLength(2)); // all sub-arrays have rank = 2 + + std::vector aSubArrs(bS), 
bSubArrs(bS), cSubArrs(bS); + + if(aRank > 2) + shape::calcSubArrShapeAndOffsets(pA->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); + for (int i = 0; i < bS; ++i) + aSubArrs[i] = aRank == 2 ? pA->getSpecialBuffer() : pA->getSpecialBuffer() + subArrOffsets[i] * pA->sizeOfT(); + + if(bRank > 2) + shape::calcSubArrShapeAndOffsets(pB->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); + for (int i = 0; i < bS; ++i) + bSubArrs[i] = bRank == 2 ? pB->getSpecialBuffer() : pB->getSpecialBuffer() + subArrOffsets[i] * pB->sizeOfT(); + + shape::calcSubArrShapeAndOffsets(pC->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); + for (int i = 0; i < bS; ++i) + cSubArrs[i] = pC->getSpecialBuffer() + subArrOffsets[i] * pC->sizeOfT(); + + PointersManager manager(A->getContext(), "mmulNxN"); + + const void** aSubArrsCuda = reinterpret_cast(manager.replicatePointer(aSubArrs.data(), aSubArrs.size() * sizeof(void*))); + const void** bSubArrsCuda = reinterpret_cast(manager.replicatePointer(bSubArrs.data(), bSubArrs.size() * sizeof(void*))); + void** cSubArrsCuda = reinterpret_cast< void **>(manager.replicatePointer(cSubArrs.data(), cSubArrs.size() * sizeof(void*))); + + const bool transA = !aMcont; + const bool transB = !bKcont; + + const int lda = (aMcont && aKcont) ? M : transA ? pA->strideAt(-2) : pA->strideAt(-1); + const int ldb = (bKcont && bNcont) ? K : transB ? pB->strideAt(-2) : pB->strideAt(-1); + const int ldc = (cMcont && cNcont) ? M : C != pC ? pC->strideAt(-2) : pC->strideAt(-1); + + const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t transBblas = transB ? 
CUBLAS_OP_T : CUBLAS_OP_N; + + union Coeff {__half _h; float _f; double _d; }; + Coeff uAlpha, uBeta; + + if(cudaType == CUDA_R_16F) { + uAlpha._h = alpha; + uBeta._h = beta; + } + else if(cudaType == CUDA_R_32F) { + uAlpha._f = alpha; + uBeta._f = beta; + } + else if(cudaType == CUDA_R_64F) { + uAlpha._d = alpha; + uBeta._d = beta; } - //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES) + auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); + auto stream = A->getContext()->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", status); + + status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &uAlpha, aSubArrsCuda, cudaAType, lda, bSubArrsCuda, cudaBType, ldb, &uBeta, cSubArrsCuda, cudaCType, ldc, bS, cudaType, CUBLAS_GEMM_DEFAULT); + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", status); + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", cudaResult); } + else { - if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + std::vector idxRanges(2 * pC->rankOf()); - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); + for(Nd4jLong i = 0; i < bS; ++i) { + + ShapeUtils::evalIdxRangesForSubArr(i, pC->getShapeInfo(), dimsToExclude, idxRanges.data()); + NDArray cSubArr = (*pC)(idxRanges); + + if(aRank > bRank) { + NDArray aSubArr = (*pA)(idxRanges); + mmulMxM(&aSubArr, pB, &cSubArr, 1., 0., pC->ordering()); + } + else if(bRank > aRank) { + NDArray bSubArr = (*pB)(idxRanges); + mmulMxM(pA, &bSubArr, &cSubArr, 1., 0, pC->ordering()); + } + else { + NDArray aSubArr = (*pA)(idxRanges); + NDArray bSubArr = (*pB)(idxRanges); + mmulMxM(&aSubArr, &bSubArr, &cSubArr, 1., 0., pC->ordering()); + } + } + } NDArray::registerSpecialUse({pC}, {pA, pB}); - if(C->ews() != 1) - C->assign(pC); + if(C != pC) + C->assign(pC->permute(permut)); for(int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; return C; } - -//////////////////////////////////////////////////////////////////////////// -// MXN x N = M -NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) { - - int xLenDim, yLenDim(0); - - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !"); - if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV cuda: X array must be vector !"); - - const auto M = A->sizeAt(0); - const auto N = A->sizeAt(1); - - if(Y != nullptr && !shape::isCommonVector(Y->getShapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array must be vector !"); - if(X->lengthOf() != N) - throw std::runtime_error("MmulHelper::mmulMxV 
cuda: X vector has wrong length !"); - if(Y != nullptr && Y->lengthOf() != M) - throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array has wrong length !"); - - if(Y == nullptr) - Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); - - NDArray *pA(const_cast(A)); - - if(A->ews() != 1) - pA = pA->dup('f'); - - const bool transA = pA->ordering() == 'c'; - - const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - - int lda, lta; - if(transA) { lda = N; lta = M; } - else { lda = M; lta = N; } - - const int incx = X->stridesOf()[xLenDim]; - const int incy = Y->stridesOf()[yLenDim]; - - const auto aType = pA->dataType(); - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - - auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); - auto stream = A->getContext()->getCudaStream(); - - auto status = cublasSetStream_v2(*handle, *stream); - if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); - - const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); - - NDArray::prepareSpecialUse({Y}, {pA, X}); - - // choose appropriate cuda gemm api depending on data types - if(AXY && aType == DataType::DOUBLE) { - status = cublasDgemv(*handle, transAblas, lda, lta, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)X->getSpecialBuffer(), incx, &beta, (double*)Y->getSpecialBuffer(), incy); - } - else if(AXY && aType == DataType::FLOAT32) { - float alphaF(alpha), betaF(beta); - status = cublasSgemv(*handle, transAblas, lda, lta, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)X->getSpecialBuffer(), incx, &betaF, (float*)Y->getSpecialBuffer(), incy); - } - else { - dim3 threadsPerBlock(M); - dim3 blocksPerGrid(1); - if (M > 512){ - threadsPerBlock.x = 512; - blocksPerGrid.x = math::nd4j_ceil(static_cast(M) / threadsPerBlock.x); // rows - } - //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES) - } - - if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); - - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult); - - NDArray::registerSpecialUse({Y}, {pA, X}); - - if(pA != A) - delete pA; - - return Y; -} - -//////////////////////////////////////////////////////////////////////////// -// (X * Y) = Z[0] -NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, const double alpha, const double beta) { - - int xLenDim(0), yLenDim(0); - - if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); - if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); - if(Z != nullptr && !Z->isScalar()) - throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); - - const auto length = X->lengthOf(); - - if(Y->lengthOf() != length) - throw std::runtime_error("MmulHelper::dot cuda: 
lengths of input vectors are different !"); - - if(Z == nullptr) - Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); - - const Nd4jLong incx = X->stridesOf()[xLenDim]; - const Nd4jLong incy = Y->stridesOf()[yLenDim]; - - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - const auto zType = Z->dataType(); - - if(!X->isActualOnDeviceSide()) X->syncToDevice(); - if(!Y->isActualOnDeviceSide()) Y->syncToDevice(); - if(!Z->isActualOnDeviceSide()) Z->syncToDevice(); - - cudaStream_t* stream = X->getContext()->getCudaStream(); - - dim3 threadsPerBlock(512); - dim3 blocksPerGrid(1); - if (length > 512) - threadsPerBlock.x = math::nd4j_ceil(static_cast(length) / 512); - - NDArray::prepareSpecialUse({Z}, {X, Y}); - - //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES) - - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); - - NDArray::registerSpecialUse({Z}, {X, Y}); - - return Z; -} +*/ //BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp index b50104bee..ab97ad137 100644 --- a/libnd4j/include/helpers/impl/MmulHelper.cpp +++ b/libnd4j/include/helpers/impl/MmulHelper.cpp @@ -184,69 +184,6 @@ NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray #endif -////////////////////////////////////////////////////////////////////////// -NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { - - const int aRank = A->rankOf(); - const int bRank = B->rankOf(); - - // input ranks validation - if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); - else if (aRank == bRank ) { - for(int i = 0; i < aRank - 2; ++i) - if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - } - - if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - - // validation of C array - std::vector cExpectedShape = aRank > bRank ? 
A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); - - if(C != nullptr ) { - if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); - } - else { - C = new NDArray(outOrder, cExpectedShape, B->dataType()); - } - - - // multiplication - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude); - std::vector idxRanges(2 * C->rankOf()); - -// #pragma omp parallel for schedule(guided) firstprivate(idxRanges) - for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { - - ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data()); - NDArray cSubArr = (*C)(idxRanges); - - if(aRank > bRank) { - NDArray aSubArr = (*A)(idxRanges); - mmulMxM(&aSubArr, B, &cSubArr, 1., 0., outOrder); - } - else if(bRank > aRank) { - NDArray bSubArr = (*B)(idxRanges); - mmulMxM(A, &bSubArr, &cSubArr, 1., 0, outOrder); - } - else { - NDArray aSubArr = (*A)(idxRanges); - NDArray bSubArr = (*B)(idxRanges); - mmulMxM(&aSubArr, &bSubArr, &cSubArr, 1., 0., outOrder); - } - } - - return C; -} - ////////////////////////////////////////////////////////////////////////// nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B, nd4j::NDArray* C , const double alpha, const double beta, const char outOrder) { diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h index b8cbf1b37..0a581d718 100644 --- a/libnd4j/include/helpers/shape.h +++ b/libnd4j/include/helpers/shape.h @@ -889,7 +889,8 @@ namespace shape { */ ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset = 0); - ND4J_EXPORT Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector& indices); + ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *indices, Nd4jLong baseOffset = 0); + ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *indices, Nd4jLong baseOffset = 0); ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank); @@ -900,7 +901,13 @@ namespace shape { * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1] */ ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords); + ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords); + ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords); ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords); + /** + * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! 
+     */
+    ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords, const int dimsSize, const int* tadDims);

@@ -909,7 +916,13 @@
      * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned
      */
     ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords);
+    ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords);
+    ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords);
     ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *coords);
+    /**
+     * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
+     */
+    ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords, const int dimsSize, const int* tadDims);

     /**
      * increment n-dimensional array by one iteration by changing coord appropriately
@@ -1748,6 +1761,34 @@ INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLo
     return index;
 }

+//////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords) {
+
+    Nd4jLong index, shift = 1;
+
+    index = coords[shapeInfo[0] - 1];
+    for(uint i = shapeInfo[0]; i > 1; --i) {
+        shift *= shapeInfo[i];
+        index += shift * coords[i - 2];
+    }
+
+    return index;
+}
+
+//////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords) {
+
+    Nd4jLong index, shift = 1;
+
+    index = coords[shapeInfo[0] - 1];
+    for(uint i = shapeInfo[0]; i > 1; --i) {
+        shift *= shapeInfo[i];
+        index += shift * coords[i - 2];
+    }
+
+    return index;
+}
+
 //////////////////////////////////////////////////////////////////////
 INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *indices) {

@@ -1762,6 +1803,19 @@ INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape,
     return index;
 }

+INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords, const int dimsSize, const int* tadDims) {
+
+    Nd4jLong index, shift = 1;
+
+    index = coords[tadDims[dimsSize - 1]];
+    for(uint i = dimsSize - 1; i >= 1; --i) {
+        shift *= shapeInfo[tadDims[i]];
+        index += shift * coords[i - 1];
+    }
+
+    return index;
+}
+
 template <typename T>
 INLINEDEF _CUDA_HD void fill(T* buffer, T value, Nd4jLong length) {

@@ -3202,18 +3256,28 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong
 }

 //////////////////////////////////////////////////////////////////////////
-INLINEDEF Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<Nd4jLong>& indices) {
+INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset) {

-    Nd4jLong offset = 0;
+    Nd4jLong offset = baseOffset;

     for(uint i = 1; i <= shapeInfo[0]; ++i)
         if(shapeInfo[i] != 1)
-            offset += indices[i - 1] * shapeInfo[shapeInfo[0] + i];
+            offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];

     return offset;
 }

+//////////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset) {
+    Nd4jLong offset = baseOffset;
+
+    for(uint i = 1; i <= shapeInfo[0]; ++i)
+        if(shapeInfo[i] != 1)
+            offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
+
+    return offset;
+}

 /**
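Aside: the getOffset overloads added above all perform the same stride dot-product over libnd4j's shapeInfo layout, where shapeInfo[0] is the rank, shapeInfo[1..rank] holds the dimensions and shapeInfo[rank+1..2*rank] holds the strides. A minimal standalone sketch of that arithmetic follows; offsetOf is a hypothetical name for illustration, not a libnd4j function:

#include <cstdint>
#include <cstdio>

using Nd4jLong = int64_t;

// shapeInfo = {rank, d0..d(r-1), s0..s(r-1), ...}: shapeInfo[i] is dimension i-1
// and shapeInfo[rank + i] is its stride, exactly as read by getOffset above
Nd4jLong offsetOf(const Nd4jLong* shapeInfo, const int* coords, Nd4jLong base = 0) {
    Nd4jLong offset = base;
    for (Nd4jLong i = 1; i <= shapeInfo[0]; ++i)
        if (shapeInfo[i] != 1)                 // size-1 dimensions contribute nothing
            offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
    return offset;
}

int main() {
    // rank-2, shape {2, 4}, c-order strides {4, 1}; trailing type/ews/order fields omitted
    Nd4jLong shapeInfo[] = {2, 2, 4, 4, 1};
    int coords[] = {1, 1};
    std::printf("%lld\n", (long long) offsetOf(shapeInfo, coords));  // prints 5, matching the doc comment above
}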
@@ -3957,9 +4021,13 @@ INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo,
         oldStart = oldStop++;
     }

-    newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo);    // order
-    newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo);    // ews
-    newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo);    // type
+    // rest of strides should be unities (if there is remainder in strides space, that is newStart < newRank)
+    for (int i = newStart; i < newRank; ++i)
+        newStrides[i] = 1;
+
+    newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo);    // order
+    newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo);    // ews
+    newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo);    // type

     return true;
 }

@@ -4695,6 +4763,26 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo,
     coords[0] = index;      // last iteration
 }

+//////////////////////////////////////////////////////////////////////
+INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords) {
+
+    for(uint i = shapeInfo[0]; i > 1; --i) {
+        coords[i - 1] = static_cast<int>(index) % static_cast<int>(shapeInfo[i]);
+        index /= static_cast<int>(shapeInfo[i]);
+    }
+    coords[0] = static_cast<int>(index);      // last iteration
+}
+
+//////////////////////////////////////////////////////////////////////
+INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords) {
+
+    for(uint i = shapeInfo[0]; i > 1; --i) {
+        coords[i - 1] = static_cast<uint>(index) % static_cast<uint>(shapeInfo[i]);
+        index /= static_cast<uint>(shapeInfo[i]);
+    }
+    coords[0] = static_cast<uint>(index);      // last iteration
+}
+
 //////////////////////////////////////////////////////////////////////
 INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords) {

@@ -4705,6 +4793,16 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jL
     coords[0] = index;      // last iteration
 }

+//////////////////////////////////////////////////////////////////////
+INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords, const int dimsSize, const int* tadDims) {
+
+    for(uint i = dimsSize - 1; i > 0; --i) {
+        coords[tadDims[i]] = index % shapeInfo[1 + tadDims[i]];
+        index /= shapeInfo[1 + tadDims[i]];
+    }
+    coords[tadDims[0]] = index;      // last iteration
+}
+
 //////////////////////////////////////////////////////////////////////
 INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
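Aside: the broadcasting_bool changes that follow thread a new void *extraParams argument through every boolean-broadcast entry point. The motivation is parameterized comparisons, e.g. the MatchCondition op that joins BROADCAST_BOOL_OPS in legacy_ops.h further down, which needs a tolerance alongside x and y. A hedged sketch of the idea only; MatchConditionLike is illustrative and not the library's actual functor:

#include <cmath>
#include <cstdio>

// A comparison in the shape of the new op(x, y, extraParams) signature:
// extraParams carries state (here an epsilon) that the old two-argument
// op(x, y) form had no way to receive.
template <typename X, typename Z>
struct MatchConditionLike {
    static Z op(X x, X y, X *extraParams) {
        const X eps = extraParams ? extraParams[0] : static_cast<X>(0);
        return static_cast<Z>(std::fabs(x - y) <= eps);   // "x matches y within eps"
    }
};

int main() {
    float eps[] = {0.01f};
    std::printf("%d\n", MatchConditionLike<float, bool>::op(1.0f, 1.005f, eps));  // prints 1
}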
diff --git a/libnd4j/include/loops/broadcasting_bool.h b/libnd4j/include/loops/broadcasting_bool.h
index 3b0958be1..7ba5fa9eb 100644
--- a/libnd4j/include/loops/broadcasting_bool.h
+++ b/libnd4j/include/loops/broadcasting_bool.h
@@ -65,13 +65,14 @@ namespace functions {
                              Nd4jLong *yShapeInfo,
                              void *result,
                              Nd4jLong *resultShapeInfo,
+                             void *extraParams,
                              int *dimension,
                              int dimensionLength,
                              Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

            template <typename OpType>
-            static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
+            static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

-            static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
+            static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

            template <typename OpType>
            static __device__ void transformInverseCuda(
@@ -81,13 +82,14 @@
                              Nd4jLong *yShapeInfo,
                              void *result,
                              Nd4jLong *resultShapeInfo,
+                             void *extraParams,
                              int *dimension,
                              int dimensionLength,
                              Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

            template <typename OpType>
-            static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
+            static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

-            static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
+            static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

 #else

@@ -98,6 +100,7 @@
                              Nd4jLong *yShapeInfo,
                              void *result,
                              Nd4jLong *resultShapeInfo,
+                             void *extraParams,
                              int *dimension,
                              int dimensionLength,
                              Nd4jLong *tadShapeInfo,
@@ -114,6 +117,7 @@
                              Nd4jLong *yShapeInfo,
                              void *result,
                              Nd4jLong *resultShapeInfo,
+                             void *extraParams,
                              int *dimension,
                              int dimensionLength,
                              Nd4jLong *tadShapeInfo,
@@ -141,6 +145,7 @@
                              Nd4jLong *yShapeInfo,
                              void *result,
                              Nd4jLong *resultShapeInfo,
+                             void *extraParams,
                              int *dimension,
                              int dimensionLength,
                              Nd4jLong *tadShapeInfo,
@@ -157,6 +162,7 @@
                              Nd4jLong *yShapeInfo,
                              void *result,
                              Nd4jLong *resultShapeInfo,
+                             void *extraParams,
                              int *dimension,
                              int dimensionLength,
                              Nd4jLong *tadShapeInfo,
diff --git a/libnd4j/include/loops/cpu/broadcasting_bool.cpp b/libnd4j/include/loops/cpu/broadcasting_bool.cpp
index 
7a3eb1e31..8d62b9506 100644 --- a/libnd4j/include/loops/cpu/broadcasting_bool.cpp +++ b/libnd4j/include/loops/cpu/broadcasting_bool.cpp @@ -39,6 +39,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, @@ -53,6 +54,7 @@ namespace functions { yShapeInfo, z, zShapeInfo, + extraParams, dimension, dimensionLength, xTadShapeInfo, @@ -69,6 +71,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, @@ -83,6 +86,7 @@ namespace functions { yShapeInfo, z, zShapeInfo, + extraParams, dimension, dimensionLength, xTadShapeInfo, @@ -99,6 +103,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, @@ -111,6 +116,7 @@ namespace functions { auto x = reinterpret_cast(vx); auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); //decompose in to several sub tads after //moving all dimensions (in sorted order) @@ -155,7 +161,7 @@ namespace functions { PRAGMA_OMP_SIMD for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], y[f]); + oZ[f] = OpType::op(oX[f], y[f], extraParams); } } else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { @@ -165,7 +171,7 @@ namespace functions { PRAGMA_OMP_SIMD for (unsigned int f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); + oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws], extraParams); }; } else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { @@ -179,7 +185,7 @@ namespace functions { PRAGMA_OMP_SIMD for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - oZ[offset] = OpType::op(oX[offset], y[offset]); + oZ[offset] = OpType::op(oX[offset], y[offset], extraParams); } }; } @@ -197,7 +203,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[offset], y[offset]); + oZ[zOffset] = OpType::op(oX[offset], y[offset], extraParams); } }; } @@ -215,7 +221,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[offset], y[yOffset]); + oZ[offset] = OpType::op(oX[offset], y[yOffset], extraParams); } }; @@ -234,7 +240,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[xOffset], y[offset]); + oZ[offset] = OpType::op(oX[xOffset], y[offset], extraParams); } }; } @@ -255,7 +261,7 @@ namespace functions { auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[xOffset], 
y[yOffset]); + oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset], extraParams); } }; } @@ -270,6 +276,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *yTadShapeInfo, @@ -282,6 +289,7 @@ namespace functions { auto x = reinterpret_cast(vx); auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); //decompose in to several sub tads after //moving all dimensions (in sorted order) @@ -326,7 +334,7 @@ namespace functions { PRAGMA_OMP_SIMD for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(x[f], oY[f]); + oZ[f] = OpType::op(x[f], oY[f], extraParams); } } else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { @@ -336,7 +344,7 @@ namespace functions { PRAGMA_OMP_SIMD for (uint f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); + oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws], extraParams); } } else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { @@ -351,7 +359,7 @@ namespace functions { PRAGMA_OMP_SIMD for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - oZ[offset] = OpType::op(x[offset], oY[offset]); + oZ[offset] = OpType::op(x[offset], oY[offset], extraParams); } } } @@ -370,7 +378,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[offset], oY[offset]); + oZ[zOffset] = OpType::op(x[offset], oY[offset], extraParams); } } } @@ -389,7 +397,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[xOffset], oY[offset]); + oZ[offset] = OpType::op(x[xOffset], oY[offset], extraParams); } } } @@ -408,7 +416,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[offset], oY[yOffset]); + oZ[offset] = OpType::op(x[offset], oY[yOffset], extraParams); } } } @@ -430,7 +438,7 @@ namespace functions { auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); + oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset], extraParams); } } } diff --git a/libnd4j/include/loops/cuda/broadcasting_bool.cu b/libnd4j/include/loops/cuda/broadcasting_bool.cu index af354a2e2..d5a45ceec 100644 --- a/libnd4j/include/loops/cuda/broadcasting_bool.cu +++ b/libnd4j/include/loops/cuda/broadcasting_bool.cu @@ -40,10 +40,11 @@ static __global__ void broadcastBoolSimple( Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - functions::broadcast::BroadcastBool::template 
transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + functions::broadcast::BroadcastBool::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); } ////////////////////////////////////////////////////////////////////////// @@ -55,10 +56,11 @@ static __global__ void broadcastBoolInverseSimple( Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - functions::broadcast::BroadcastBool::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + functions::broadcast::BroadcastBool::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,extraParams,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); } namespace functions { @@ -66,15 +68,15 @@ namespace functions { ////////////////////////////////////////////////////////////////////////// template template - __host__ void BroadcastBool::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - broadcastBoolSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + __host__ void BroadcastBool::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + broadcastBoolSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) 
failed"); } ////////////////////////////////////////////////////////////////////////// template - __host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) + __host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) DEBUG_KERNEL(stream, opNum); } @@ -82,15 +84,15 @@ namespace functions { ////////////////////////////////////////////////////////////////////////// template template - __host__ void BroadcastBool::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - broadcastBoolInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + __host__ void BroadcastBool::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + broadcastBoolInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) 
failed"); } ////////////////////////////////////////////////////////////////////////// template - __host__ void BroadcastBool::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) + __host__ void BroadcastBool::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) DEBUG_KERNEL(stream, opNum); } @@ -102,6 +104,7 @@ namespace functions { void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { @@ -113,6 +116,7 @@ namespace functions { auto x = reinterpret_cast(vx); auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); //decompose in to several sub tads after //moving all dimensions (in sorted order) @@ -140,7 +144,7 @@ namespace functions { if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); + rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS], extraParams); } else { // it is expected that x and z tads and y array all have the same length @@ -149,7 +153,7 @@ namespace functions { auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); + rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset], extraParams); } } } @@ -162,6 +166,7 @@ namespace functions { void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { @@ -173,6 +178,7 @@ namespace functions { auto x = reinterpret_cast(vx); auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); //decompose in to several sub tads after //moving all dimensions (in sorted order) @@ -208,7 +214,7 @@ namespace functions { if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); + rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS], extraParams); } else { // it is expected that x and z tads and y array all have the same length @@ -217,7 +223,7 @@ namespace functions { auto yOffset = 
shape::getIndexOffset(i, yShapeInfo); auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); + rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset], extraParams); } } } diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 92fd58d7a..5108ba4a7 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -45,13 +45,13 @@ (2, LessThan),\ (3, Epsilon),\ (4, GreaterThanOrEqual),\ - (5, LessThanOrEqual),\ + (5, MatchCondition) ,\ (6, NotEqualTo),\ (7, And),\ (8, Or),\ (9, Xor) ,\ - (10, Not) - + (10, Not) ,\ + (11, LessThanOrEqual) #define BROADCAST_OPS \ (0, Add), \ @@ -169,7 +169,9 @@ (53, GELU) ,\ (54, GELUDerivative), \ (55, PreciseGELU) ,\ - (56, PreciseGELUDerivative) + (56, PreciseGELUDerivative), \ + (57, Mish),\ + (58, MishDerivative) // these ops return one of FLOAT data types #define TRANSFORM_FLOAT_OPS \ @@ -196,12 +198,13 @@ (2, LessThan),\ (3, Epsilon),\ (4, GreaterThanOrEqual),\ - (5, LessThanOrEqual),\ + (5, MatchCondition) ,\ (6, NotEqualTo),\ (7, And),\ (8, Or),\ (9, Xor) ,\ - (10, Not) + (10, Not) ,\ + (11, LessThanOrEqual) #define SCALAR_OPS \ (0, Add),\ @@ -339,12 +342,13 @@ (2, LessThan),\ (3, Epsilon),\ (4, GreaterThanOrEqual),\ - (5, LessThanOrEqual),\ + (5, MatchCondition) ,\ (6, NotEqualTo),\ (7, And),\ (8, Or),\ (9, Xor) ,\ - (10, Not) + (10, Not) ,\ + (11, LessThanOrEqual) #define PAIRWISE_TRANSFORM_OPS \ (0, Add),\ diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index 5da74860b..ea1f20d34 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -32,6 +32,7 @@ #include #include #include +#include //#include #include @@ -111,7 +112,7 @@ namespace nd4j { */ int prepareOutputs(Context& block); - //std::vector* calculateOutputShape(std::vector* inputShape, nd4j::graph::Block& block); + virtual samediff::EmptyHandling emptyHandling(); public: // for special cases, like BooleanOps DeclarableOp(); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/I18NRoute.java b/libnd4j/include/ops/declarable/EmptyHandling.h similarity index 58% rename from deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/I18NRoute.java rename to libnd4j/include/ops/declarable/EmptyHandling.h index e65170872..c25fea498 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/I18NRoute.java +++ b/libnd4j/include/ops/declarable/EmptyHandling.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
@@ -14,24 +14,19 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.deeplearning4j.ui.play.staticroutes;
+//
+// @author raver119@gmail.com
+//

-import org.deeplearning4j.ui.i18n.I18NProvider;
-import play.mvc.Result;
+#ifndef SAMEDIFF_EMPTYHANDLING_H
+#define SAMEDIFF_EMPTYHANDLING_H

-import java.util.function.Function;
-
-import static play.mvc.Results.ok;
-
-/**
- * Route for global internationalization setting
- *
- * @author Alex Black
- */
-public class I18NRoute implements Function<String, Result> {
-    @Override
-    public Result apply(String s) {
-        I18NProvider.getInstance().setDefaultLanguage(s);
-        return ok();
-    }
+namespace samediff {
+    enum EmptyHandling {
+        EMPTY_SKIP = 1,
+        EMPTY_EXCEPTION = 2,
+        EMPTY_EXECUTE = 3
+    };
 }
+
+#endif //SAMEDIFF_EMPTYHANDLING_H
diff --git a/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp b/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp
index 563965147..37fa1350a 100644
--- a/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp
+++ b/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp
@@ -27,9 +27,12 @@ namespace nd4j {
     namespace ops {
         BOOLEAN_OP_IMPL(is_non_decreasing, 1, true) {
-
            auto input = INPUT_VARIABLE(0);

+            // in case of empty input there's nothing to do
+            if (input->isEmpty())
+                return ND4J_STATUS_TRUE;
+
            bool isNonDecreasing = true;
            nd4j::ops::helpers::compare_elem(block.launchContext(), input, false, isNonDecreasing);
diff --git a/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp b/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp
index a238eedac..a41852394 100644
--- a/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp
+++ b/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp
@@ -27,9 +27,12 @@ namespace nd4j {
     namespace ops {
         BOOLEAN_OP_IMPL(is_strictly_increasing, 1, true) {
-
            auto input = INPUT_VARIABLE(0);

+            // in case of empty input there's nothing to do
+            if (input->isEmpty())
+                return ND4J_STATUS_TRUE;
+
            bool isStrictlyIncreasing = true;
            nd4j::ops::helpers::compare_elem(block.launchContext(), input, true, isStrictlyIncreasing);
diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp
index be9edd538..26d03358a 100644
--- a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp
@@ -35,13 +35,13 @@ namespace ops {

 CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
-    
+
     auto input   = INPUT_VARIABLE(0);    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
     auto weights = INPUT_VARIABLE(1);    // [kH, kW, iC, oC] always
     auto bias    = block.width() > 2 ? 
INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - + int sH = INT_ARG(2); // strides height int sW = INT_ARG(3); // strides width int pH = INT_ARG(4); // paddings height @@ -60,8 +60,8 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC}); REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); @@ -99,21 +99,21 @@ DECLARE_SHAPE_FN(conv2d) { if(!isNCHW) { indIOioC = 3; indIiH = 1; } - else { + else { indIOioC = 1; indIiH = 2; - } + } const int bS = inputShapeInfo[1]; // batch size const int iH = inputShapeInfo[indIiH+1]; // input height const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels + const int iC = inputShapeInfo[indIOioC+1]; // input channels const int oC = weightsShapeInfo[indWoC+1]; // output channels std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC}); REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) + if (biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - + Nd4jLong* outputShapeInfo = nullptr; ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); @@ -132,9 +132,9 @@ DECLARE_SHAPE_FN(conv2d) { outputShapeInfo[3] = oW; outputShapeInfo[4] = oC; } - + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo)); - + return SHAPELIST(CONSTANT(outputShapeInfo)); } @@ -153,18 +153,18 @@ DECLARE_SHAPE_FN(conv2d) { } -////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { - + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradB = block.width() > 3 ? 
OUTPUT_VARIABLE(2) : nullptr; // [oC] - + int kH = INT_ARG(0); // filter(kernel) height int kW = INT_ARG(1); // filter(kernel) width int sH = INT_ARG(2); // strides height @@ -195,7 +195,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); - + return Status::OK(); } @@ -229,14 +229,14 @@ DECLARE_SHAPE_FN(conv2d_bp) { if(!isNCHW) { indIOioC = 3; indIiH = 1; indOoH = 1; } - else { + else { indIOioC = 1; indIiH = 2; indOoH = 2; - } + } const int bS = inputShapeInfo[1]; // batch size const int iH = inputShapeInfo[indIiH+1]; // input height const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels + const int iC = inputShapeInfo[indIOioC+1]; // input channels const int oC = weightsShapeInfo[indWoC+1]; // output channels int trueoH, trueoW; // true output height, width @@ -248,27 +248,27 @@ DECLARE_SHAPE_FN(conv2d_bp) { REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if(biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); if(biasShapeInfo) { auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); - } + } return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); } -////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { - + auto gradIShape = INPUT_VARIABLE(0); // [4] auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - + int kH = INT_ARG(0); // filter(kernel) height int kW = INT_ARG(1); // filter(kernel) width int sH = INT_ARG(2); // strides height @@ -284,9 +284,9 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf()); - REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf()); + 
REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf()); - // create empty conv2d input array + // create empty conv2d input array std::vector gradIShapeAsVector(rank); for(int i = 0; i < rank; ++i) gradIShapeAsVector[i] = gradIShape->e(i); @@ -306,7 +306,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); - + return Status::OK(); } @@ -325,8 +325,8 @@ DECLARE_SHAPE_FN(conv2d_input_bp) { auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next const int rank = 4; - - REQUIRE_TRUE(gradIShapeShapeInfo[0] == 1, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, gradIShapeShapeInfo[0]); + + REQUIRE_TRUE(gradIShapeShapeInfo[0] == 1, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, gradIShapeShapeInfo[0]); REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]); @@ -345,18 +345,18 @@ DECLARE_SHAPE_FN(conv2d_input_bp) { if(!isNCHW) { indIOioC = 3; indIiH = 1; indOoH = 1; } - else { + else { indIOioC = 1; indIiH = 2; indOoH = 2; - } + } std::vector gradIShape = INPUT_VARIABLE(0)->template asVectorT(); const int bS = gradIShape[0]; // batch size const int iH = gradIShape[indIiH]; // input height const int iW = gradIShape[indIiH+1]; // input width - const int iC = gradIShape[indIOioC]; // input channels + const int iC = gradIShape[indIOioC]; // input channels const int oC = weightsShapeInfo[indWoC+1]; // output channels - + int trueoH, trueoW; // true output height, width ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); @@ -364,8 +364,8 @@ DECLARE_SHAPE_FN(conv2d_input_bp) { std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC}); REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - - Nd4jLong* gradIshapeInfo(nullptr); + + Nd4jLong* gradIshapeInfo(nullptr); ALLOCATE(gradIshapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); gradIshapeInfo[0] = rank; @@ -380,7 +380,7 @@ DECLARE_SHAPE_FN(conv2d_input_bp) { gradIshapeInfo[3] = iW; gradIshapeInfo[4] = iC; } - + ShapeUtils::updateStridesAndType(gradIshapeInfo, gradOShapeInfo, shape::order(gradOShapeInfo)); 
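// For reference, the output-size arithmetic these conv2d shape functions rely on
// (via ConvolutionUtils::calcOutSizePool2D) follows the usual SAME/VALID rules.
// A minimal standalone sketch with illustrative names, assuming TF-style SAME padding
// (not the library's implementation):
static void calcConv2dOutSizeSketch(int& oH, int& oW, int iH, int iW,
                                    int kH, int kW, int sH, int sW,
                                    int dH, int dW, bool isSameMode) {
    const int ekH = (kH - 1) * dH + 1;     // effective (dilated) kernel height
    const int ekW = (kW - 1) * dW + 1;     // effective (dilated) kernel width
    if (isSameMode) {                      // SAME: oH = ceil(iH / sH), padding chosen to match
        oH = (iH + sH - 1) / sH;
        oW = (iW + sW - 1) / sW;
    } else {                               // VALID: no implicit padding
        oH = (iH - ekH) / sH + 1;
        oW = (iW - ekW) / sW + 1;
    }
}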
return SHAPELIST(CONSTANT(gradIshapeInfo));

diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp
index 8ed4a908e..3b8c51bc7 100644
--- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp
@@ -66,10 +66,8 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
     if(!isNCHW)
         output = new NDArray(output->permute({0, 3, 1, 2}));        // [bS, oH, oW, oC] -> [bS, oC, oH, oW]

-    if(isSameMode){          // SAME
-        //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
+    if(isSameMode)           // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
-    }

     NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext());

diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp
index b926a3a1a..1baccbe0e 100644
--- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp
@@ -67,16 +67,16 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
     if(!isNCDHW)
         output = new NDArray(output->permute({0, 4, 1, 2, 3}));     // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]

-    if(isSameMode)          //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
+    if(isSameMode)          // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
         ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

-    auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
+    NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());

     //----- calculation of output -----//
     // NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
     // NCDHW: [kD, kH, kW, oC, iC] x [bS, iC, iD, iH, iW] = [kD, kH, kW, oC, bS, iD, iH, iW]
     nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7});  // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]

-    ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW);    // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
+    ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW);   // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]

     //----- add biases if required -----//
     if(bias)

diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp
index 3e81be61d..392cd3128 100644
--- a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp
@@ -35,7 +35,7 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
     NDArray *input        = INPUT_VARIABLE(0);                  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
     NDArray *weightsDepth = INPUT_VARIABLE(1);                  // [kH, kW, iC, mC] always
     NDArray *weightsPoint = nullptr;                            // [1, 1, iC*mC, oC] always
-    NDArray *bias = nullptr;                                    // [oC], oC = iC*mC if weightsPoint=nullptr
+    NDArray *bias = nullptr;                                    // [oC], if weightsPoint=nullptr then oC = iC*mC
     NDArray *output = OUTPUT_VARIABLE(0);                      // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

@@ -300,7 +300,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {

 DECLARE_SHAPE_FN(sconv2d_bp) {
     auto inputShapeInfo    = inputShape->at(0);                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto gradOShapeInfo    = inputShape->at(1);                 // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
+    auto gradOShapeInfo    = inputShape->at(1);                 // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
     auto weightsDShapeInfo = inputShape->at(2);                 // [kH, kW, iC, mC] always
     Nd4jLong* weightsPShapeInfo = nullptr;                      // [1, 1, iC*mC, oC] always
     Nd4jLong* biasShapeInfo     = nullptr;                      // [oC], oC = iC*mC if weightsPoint=nullptr

diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/crop_and_resize.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/crop_and_resize.cpp
index fa25b1067..7afb24bd7 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/crop_and_resize.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/crop_and_resize.cpp
@@ -39,6 +39,7 @@ namespace nd4j {
         double extrapolationVal = 0.;

         auto newImageSize = INPUT_VARIABLE(3);
+        REQUIRE_TRUE(output->dataType() == image->dataType(), 0, "crop_and_resize: Source images and output should have the same data type.");
         REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "crop_and_resize: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
         //REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already given by the second param. Int params are expensive.");
         //width = int(newImageSize->getScalar(0));
@@ -57,6 +58,7 @@ namespace nd4j {

     DECLARE_SHAPE_FN(crop_and_resize) {
         auto in = inputShape->at(0);
+        auto boxShape = inputShape->at(1);

         Nd4jLong outputShape[4];

@@ -68,22 +70,22 @@ namespace nd4j {
         width = newImageSize->e(0);
         height = newImageSize->e(1);

-        outputShape[0] = in[1];
+        outputShape[0] = boxShape[1];
         outputShape[1] = width;
         outputShape[2] = height;
         outputShape[3] = in[4];

-        return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(DataType::FLOAT32, shape::order(in), outputShape, 4)));
+        return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(in), shape::order(in), outputShape, 4)));
     }

     DECLARE_TYPES(crop_and_resize) {
         getOpDescriptor()
                 ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
//                 ->setAllowedInputTypes(1, {ALL_FLOATS})
-                ->setAllowedInputTypes(1, {FLOAT32}) // as TF
+                ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS})
                 ->setAllowedInputTypes(2, {ALL_INTS})
                 ->setAllowedInputTypes(3, {ALL_INTS})
-                ->setAllowedOutputTypes({FLOAT32}); // as TF
+                ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); // as TF
//                ->setAllowedOutputTypes({ALL_FLOATS});
     }
 }

diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/draw_bounding_boxes.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/draw_bounding_boxes.cpp
index 7610e6161..8578f287e 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/draw_bounding_boxes.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/draw_bounding_boxes.cpp
@@ -29,9 +29,31 @@ namespace nd4j {
         auto images = INPUT_VARIABLE(0);
         auto boxes  = INPUT_VARIABLE(1);
-        auto colors = INPUT_VARIABLE(2);
-        auto output = OUTPUT_VARIABLE(0);
+        auto colors = (NDArray*) nullptr;
+        if (block.width() > 2)          // TF v1.x omits the color set for boxes and uses color 1.0 for fill,
+            colors = INPUT_VARIABLE(2); // but v2.x requires a color set
+
+        auto output = OUTPUT_VARIABLE(0);
+        REQUIRE_TRUE(images->dataType() == output->dataType(), 0, "draw_bounding_boxes: Input and Output types "
+                     "should be equal, but %d and %d occurred.",
+                     (int)images->dataType(), (int)output->dataType());
+        REQUIRE_TRUE(images->rankOf() == 4, 0, "draw_bounding_boxes: Images input should be a 4D tensor, but %i occurred.",
+                     images->rankOf());
+        REQUIRE_TRUE(boxes->rankOf() == 3, 0, "draw_bounding_boxes: Boxes should be a 3D tensor, but %i occurred.",
+                     boxes->rankOf());
+        if (colors) {
+
+            REQUIRE_TRUE(colors->rankOf() == 2, 0, "draw_bounding_boxes: Color set should be a 2D matrix, but %i occurred.",
+                         colors->rankOf());
+            REQUIRE_TRUE(colors->sizeAt(1) >= images->sizeAt(3), 0, "draw_bounding_boxes: Color set last dim "
+                         "should be no less than the image depth, but "
+                         "%lld and %lld occurred.",
+                         colors->sizeAt(1), images->sizeAt(3));
+        }
+        REQUIRE_TRUE(boxes->sizeAt(0) == images->sizeAt(0), 0, "draw_bounding_boxes: Batches for images and boxes "
+                     "should be the same, but %lld and %lld occurred.",
+                     images->sizeAt(0), boxes->sizeAt(0));
         helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, colors, output);
         return ND4J_STATUS_OK;
     }

diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp
index e19c030c4..6d24827e5 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp
@@ -46,6 +46,8 @@
                 max = &m2;
             }
             auto output = OUTPUT_VARIABLE(0);
+            REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars: input and output data types must be the same");
+
             int numBits = 8;
             if (block.getIArguments() && block.getIArguments()->size())
                 numBits = INT_ARG(0);

diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp
index e5873d9dd..5874d2f81 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp
@@ -40,8 +40,10 @@
             REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be"
                          "%lld, but %lld occurs.", depth, max->lengthOf());
             auto output = OUTPUT_VARIABLE(0);
+
+            REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars_per_channel: input and output data types must be the same");
+
             int numBits = 8;
             if (block.getIArguments() && block.getIArguments()->size())
                 numBits = INT_ARG(0);

diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/image_resize.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/image_resize.cpp
new file mode 100644
index 000000000..cf26b69b3
--- /dev/null
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/image_resize.cpp
@@ -0,0 +1,91 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
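// The fake_quant checks above only validate dtypes; the quantization itself relies
// on the standard min/max "nudging" scheme (see the fake_quantization.cpp hunk
// further below, which rounds and casts the zero point the same way). A rough
// standalone sketch of that scheme, with illustrative names, assuming quantMin and
// quantMax are 0 and 2^numBits - 1:
#include <cmath>
#include <cstdint>

static void nudgeRangeSketch(float min, float max, int quantMin, int quantMax,
                             float* scale, float* nudgedMin, float* nudgedMax) {
    *scale = (max - min) / static_cast<float>(quantMax - quantMin);
    const float zeroPointFromMin = quantMin - min / *scale;
    uint16_t nudgedZeroPoint;                 // clamp the zero point into the representable range
    if (zeroPointFromMin < quantMin)
        nudgedZeroPoint = static_cast<uint16_t>(quantMin);
    else if (zeroPointFromMin > quantMax)
        nudgedZeroPoint = static_cast<uint16_t>(quantMax);
    else
        nudgedZeroPoint = static_cast<uint16_t>(std::round(zeroPointFromMin));
    *nudgedMin = (quantMin - nudgedZeroPoint) * (*scale);
    *nudgedMax = (quantMax - nudgedZeroPoint) * (*scale);
}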
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author sgazeos@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_image_resize) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(image_resize, 2, 1, false, 0, 0) { + + auto image = INPUT_VARIABLE(0); + auto size = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + int width; + int height; + bool preserveAspectRatio = false; // - default value + bool antialias = false; + REQUIRE_TRUE(size->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %lld.", size->lengthOf()); + width = size->e(0); + height = size->e(1); + if (block.getBArguments()->size()) { + preserveAspectRatio = B_ARG(0); + if (block.getBArguments()->size() > 1) + antialias = B_ARG(1); + } + + auto method = helpers::ImageResizeMethods::kResizeBilinear; + if (block.numI() == 1) { + method = (helpers::ImageResizeMethods)INT_ARG(0); + } + + return helpers::resizeFunctor(block.launchContext(), image, width, height, method, preserveAspectRatio, antialias, output); + } + + DECLARE_SHAPE_FN(image_resize) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); + + Nd4jLong* outputShape; + + int width; + int height; + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. 
Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); + + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong); + outputShape[0] = 4; + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; + } + DECLARE_TYPES(image_resize) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}); + } + + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp index e2fe58b7a..c56e32f31 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp @@ -37,6 +37,9 @@ namespace nd4j { else REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); + if (boxes->isEmpty() || scales->isEmpty()) + return Status::OK(); + REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1)); REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf()); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp new file mode 100644 index 000000000..0c1aeba61 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
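// A hedged usage sketch for the new image_resize op. The execute(...) call follows
// the initializer-list pattern used throughout libnd4j's op tests; shapes and values
// here are illustrative only, and passing no int arg leaves the default bilinear method.
#include <NDArrayFactory.h>
#include <ops/declarable/CustomOperations.h>

void imageResizeSketch() {
    auto image  = NDArrayFactory::create<float>('c', {1, 8, 8, 3});    // [batch, H, W, C]
    auto size   = NDArrayFactory::create<int>('c', {2}, {16, 16});     // (newWidth, newHeight)
    auto output = NDArrayFactory::create<float>('c', {1, 16, 16, 3});

    nd4j::ops::image_resize op;
    // bool args map to preserve_aspect_ratio and antialias (both false by default);
    // an optional int arg would select the method from helpers::ImageResizeMethods.
    op.execute({&image, &size}, {&output}, {}, {}, {false, false});
}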
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author sgazeos@gmail.com
+//
+
+#include 
+#if NOT_EXCLUDED(OP_resize_bicubic)
+
+#include 
+#include 
+
+namespace nd4j {
+    namespace ops {
+        CUSTOM_OP_IMPL(resize_bicubic, 2, 1, false, 0, 0) {
+
+            auto image = INPUT_VARIABLE(0);
+            auto size = INPUT_VARIABLE(1); // integer vector with shape {2} and content (new_height, new_width)
+            size->syncToHost();
+            auto output = OUTPUT_VARIABLE(0);
+            int width;
+            int height;
+            auto inRank = image->rankOf();
+            REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 3 or 4, but %i given.", inRank);
+            REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf());
+            REQUIRE_TRUE(size->rankOf() == 1 && size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf());
+            REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bicubic: Resize params already given by the second param. Int params are expensive.");
+            width = size->e(1);
+            height = size->e(0);
+            REQUIRE_TRUE(width > 0, 0, "resize_bicubic: picture width should be a positive 32-bit integer, but %i given", width);
+            REQUIRE_TRUE(height > 0, 0, "resize_bicubic: picture height should be a positive 32-bit integer, but %i given", height);
+            //REQUIRE_TRUE(image->sizeAt(1) > 3 && image->sizeAt(2) > 3, 0, "resize_cubic: To use bicubic algorithm need at least 16 pixels as source.");
+            REQUIRE_TRUE(width > 3 && height > 3, 0, "resize_bicubic: The bicubic algorithm needs at least 16 (4x4) target pixels.");
+            REQUIRE_TRUE(image->lengthOf() > 0, 0, "resize_bicubic: Only non-empty images are allowed for processing.");
+//            auto method = 1; //kResizeBilinear;
+//            if (block.numI() == 1) {
+//                method = INT_ARG(0);
+//            }
+            auto alignCorners = false;
+            auto halfPixelAlign = false;
+            if (block.numB() > 0) {
+                alignCorners = block.getBArguments()->at(0);
+                if (block.numB() > 1)
+                    halfPixelAlign = block.getBArguments()->at(1);
+            }
+            REQUIRE_TRUE(halfPixelAlign == false || (halfPixelAlign == true && alignCorners == false), 0, "resize_bicubic: half pixel align can be used only with non-aligned corners");
+
+            auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
+            auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
+
+            return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
+        }
+
+        DECLARE_SHAPE_FN(resize_bicubic) {
+            auto shapeList = SHAPELIST();
+            auto in = inputShape->at(0);
+
+            Nd4jLong* outputShape;
+            auto inRank = shape::rank(in);
+            int width;
+            int height;
+            auto newImageSize = INPUT_VARIABLE(1);
+            REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
+            REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bicubic: Resize params already given by the second param.
Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); + + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank); + + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + } + else { + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[3]; + } + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; + } + DECLARE_TYPES(resize_bicubic) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {DataType::INT32}) + ->setAllowedOutputTypes({ALL_FLOATS}); + } + + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp index 0f3a0a490..f60f14fdc 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp @@ -33,6 +33,15 @@ namespace nd4j { int width; int height; bool center = false; // - default value + auto inRank = image->rankOf(); + REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D " + "tensor, but input has rank %i", + image->rankOf()); + REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); + + auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); + auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); + if (block.width() > 1) { auto newImageSize = INPUT_VARIABLE(1); REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); @@ -51,7 +60,7 @@ namespace nd4j { center = 0 != INT_ARG(2); } - return helpers::resizeBilinearFunctor(block.launchContext(), image, width, height, center, output); + return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target); } DECLARE_SHAPE_FN(resize_bilinear) { @@ -59,6 +68,10 @@ namespace nd4j { auto in = inputShape->at(0); Nd4jLong* outputShape; + auto inRank = shape::rank(in); + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D " + "tensor, but input has rank %i", + inRank); int width; int height; @@ -75,12 +88,19 @@ namespace nd4j { height = INT_ARG(1); } - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong); - outputShape[0] = 4; - outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; - outputShape[4] = in[4]; + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + } + else { // input shape is 3D, so 
result also should be 3D + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[3]; + } ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); shapeList->push_back(CONSTANT(outputShape)); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp index b36c56579..249c504dc 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp @@ -51,16 +51,25 @@ namespace nd4j { if (block.numI() == 3) center = 0 != INT_ARG(2); } + auto inRank = image->rankOf(); + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured"); + REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); + auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); + auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); - return helpers::resizeNeighborFunctor(block.launchContext(), image, width, height, center, output); + return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target); } DECLARE_SHAPE_FN(resize_nearest_neighbor) { auto shapeList = SHAPELIST(); auto in = inputShape->at(0); - + auto inRank = shape::rank(in); Nd4jLong* outputShape; + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D " + "tensor, but input has rank %i", + inRank); + int width; int height; if (block.width() > 1) { @@ -75,13 +84,20 @@ namespace nd4j { width = INT_ARG(0); height = INT_ARG(1); } - - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong); - outputShape[0] = 4; - outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; - outputShape[4] = in[4]; + + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + } + else { // input shape is 3D, so result also should be 3D + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[3]; + } ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); shapeList->push_back(CONSTANT(outputShape)); diff --git a/libnd4j/include/ops/declarable/generic/shape/create.cpp b/libnd4j/include/ops/declarable/generic/shape/create.cpp new file mode 100644 index 000000000..e743a5cad --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/shape/create.cpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
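// The 3D/4D branching above boils down to one small shape rule, sketched here with
// an illustrative helper (same arithmetic as the resize shape functions in this change):
#include <vector>

static std::vector<Nd4jLong> resizedShapeSketch(const std::vector<Nd4jLong>& in,
                                                Nd4jLong width, Nd4jLong height) {
    if (in.size() == 4)
        return {in[0], width, height, in[3]};   // [bS, H, W, C] -> [bS, newW, newH, C]
    return {width, height, in[2]};              // [H, W, C]     -> [newW, newH, C]
}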
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_shapes_of) + +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(create, 1, 1, false, 0, 1) { + auto init = block.numB() > 0 ? B_ARG(0) : true; + + if (init) + OUTPUT_VARIABLE(0)->nullify(); + + return Status::OK(); + } + + DECLARE_SHAPE_FN(create) { + auto shapeInput = INPUT_VARIABLE(0); + auto order = (char) INT_ARG(0); + auto dtype = DataTypeUtils::fromInt(INT_ARG(1)); + + REQUIRE_TRUE(order == 'c' || order == 'f', 0, "create: order must be either c or f"); + + auto shape = shapeInput->getBufferAsVector(); + + return SHAPELIST(nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, order, shape)); + } + + DECLARE_TYPES(create) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setAllowedOutputTypes(nd4j::DataType::ANY); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 590d99308..f6ed43f47 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -1594,7 +1595,7 @@ namespace nd4j { DECLARE_CUSTOM_OP(reduce_logsumexp, 1, 1, false, 0, 0); #endif - /** + /** * This op make bilinear or nearest neighbor interpolated resize for given tensor * * input array: @@ -1616,7 +1617,7 @@ namespace nd4j { DECLARE_CUSTOM_OP(crop_and_resize, 4, 1, false, -1, -1); #endif - /** + /** * This op make bilinear interpolated resize for given tensor * * input array: @@ -1637,7 +1638,7 @@ namespace nd4j { DECLARE_CUSTOM_OP(resize_bilinear, 1, 1, false, 0, -2); #endif - /** + /** * This op make nearest neighbor interpolated resize for given tensor * * input array: @@ -1649,7 +1650,7 @@ namespace nd4j { * 1 - new height * * output array: - * the 4D-Tensor with calculated backproped dots + * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels}) * * CAUTION: either size tensor or a pair of int params should be provided. */ @@ -1658,21 +1659,57 @@ namespace nd4j { DECLARE_CUSTOM_OP(resize_nearest_neighbor, 1, 1, false, 0, -2); #endif - /** - * This op calculates backprop dot for two tensors along given dimensions + /** + * This op make bicubic interpolated resize for given tensor * * input array: - * x: tensor to calculate dot for - * y: tensor to calculate dot for - * z: tensor with gradient output of the FF dot for x and y - * - * int arguments: - * list of integers - dimensions to calculate dot along, - * default corresponds to empty list in which case calculation - * is performed for all dimensions and scalar is returned. 
+     * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels)
+     * 1 - 1D-Tensor with 2 values (newWidth, newHeight)
      *
      * output array:
-     *    the tensor with calculated backproped dots
+     *    the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels})
+     *
+     */
+    #if NOT_EXCLUDED(OP_resize_bicubic)
+        DECLARE_CUSTOM_OP(resize_bicubic, 1, 1, false, 0, -2);
+    #endif
+
+    /**
+     * This op makes an interpolated resize of the given tensor, using the given algorithm.
+     * Supported algorithms are bilinear, bicubic, nearest_neighbor.
+     * Still to be implemented for full compatibility with TF: lanczos5, gaussian, area and mitchellcubic
+     *
+     * input array:
+     *    0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels)
+     *    1 - 1D-Tensor with 2 values (newWidth, newHeight)
+     *
+     * optional int args:
+     *    0 - algorithm - bilinear by default
+     * optional bool args:
+     *    0 - preserve_aspect_ratio - default False
+     *    1 - antialias - default False
+     *
+     * output array:
+     *    the 4D-Tensor with the image resized by the given algorithm (shape is {batch, newWidth, newHeight, channels})
+     *
+     */
+
+    #if NOT_EXCLUDED(OP_image_resize)
+        DECLARE_CUSTOM_OP(image_resize, 2, 1, false, 0, 0);
+    #endif
+
+    /**
+     * Copy a tensor, setting everything outside a central band in each innermost matrix
+     *
+     * input array:
+     *    x: given tensor with shape {..., M, N} - as a vector (matrix) of MxN matrices
+     *
+     * int arguments:
+     *    lower band
+     *    upper band
+     *
+     * output array:
+     *    matrix with given bands between lower and upper diagonals
+     *
+     */

@@ -1684,7 +1721,8 @@
    #if NOT_EXCLUDED(OP_Assert)
        DECLARE_OP(Assert, 1, 1, false);
    #endif
-   /*
+
+   /**
    * image.non_max_suppression op.
    * input:
    *     0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type

diff --git a/libnd4j/include/ops/declarable/headers/shape.h b/libnd4j/include/ops/declarable/headers/shape.h
index 9e0c19a7b..3d47c24bf 100644
--- a/libnd4j/include/ops/declarable/headers/shape.h
+++ b/libnd4j/include/ops/declarable/headers/shape.h
@@ -99,6 +99,22 @@ namespace nd4j {
        #if NOT_EXCLUDED(OP_evaluate_reduction_shape)
        DECLARE_CUSTOM_OP(evaluate_reduction_shape, 2, 1, false, 0, 0);
        #endif
+
+       /**
+        * This operation creates a new array
+        * Input:
+        *    array with shape values
+        *
+        * IArgs:
+        *    order value
+        *    data type value
+        *
+        * BArgs:
+        *    initialization option
+        */
+       #if NOT_EXCLUDED(OP_create)
+       DECLARE_CUSTOM_OP(create, 1, 1, false, 0, 1);
+       #endif
    }
}

diff --git a/libnd4j/include/ops/declarable/helpers/convolutions.h b/libnd4j/include/ops/declarable/helpers/convolutions.h
index 68cfc8d05..81695c9ac 100644
--- a/libnd4j/include/ops/declarable/helpers/convolutions.h
+++ b/libnd4j/include/ops/declarable/helpers/convolutions.h
@@ -69,8 +69,8 @@ namespace nd4j {
                eKH = kH;
                eKW = kW;
            } else {
-                eKH = kH + (kH - 1) * (dH - 1);
-                eKW = kW + (kW - 1) * (dW - 1);
+                eKH = (kH - 1) * dH + 1;
+                eKW = (kW - 1) * dW + 1;
            }

            pH = ((oH - 1) * sH + eKH - iH) / 2;  //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
@@ -84,9 +84,9 @@
                eKH = kH;
                eKW = kW;
            } else {
-                eKD = kD + (kD - 1) * (dD - 1);
-                eKH = kH + (kH - 1) * (dH - 1);
-                eKW = kW + (kW - 1) * (dW - 1);
+                eKD = (kD - 1) * dD + 1;
+                eKH = (kH - 1) * dH + 1;
+                eKW = (kW - 1) * dW + 1;
            }

            pD = ((oD - 1) * sD + eKD - iD) / 2;  // Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
@@ -107,8 +107,8 @@
                ekH = kH;
                ekW = kW;
            } else {
-                ekH = kH + (kH - 1) * (dH - 1);
-                ekW = kW + (kW - 1) * (dW - 1);
+                ekH = (kH - 1) * dH + 1;
+                ekW = (kW - 1) * dW + 1;
            }
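// The effective-kernel rewrite in these hunks (continued below) is a pure
// simplification, not a behavior change: kH + (kH - 1) * (dH - 1) and
// (kH - 1) * dH + 1 are algebraically identical. A tiny self-check documenting
// the identity:
#include <cassert>

static void effectiveKernelIdentityCheck() {
    for (int k = 1; k <= 7; ++k)
        for (int d = 1; d <= 4; ++d)
            assert(k + (k - 1) * (d - 1) == (k - 1) * d + 1);
}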
oH = sH * (iH - 1) + ekH - 2 * pH; @@ -131,9 +131,9 @@ namespace nd4j { ekW = kW; } else { - ekD = kD + (kD - 1) * (dD - 1); - ekH = kH + (kH - 1) * (dH - 1); - ekW = kW + (kW - 1) * (dW - 1); + ekD = (kD - 1) * dD + 1; + ekH = (kH - 1) * dH + 1; + ekW = (kW - 1) * dW + 1; } oD = sD * (iD - 1) + ekD - 2 * pD; oH = sH * (iH - 1) + ekH - 2 * pH; @@ -194,53 +194,53 @@ namespace nd4j { } - static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) { + // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) { - if(kH != 1) { - if(isSameMode) { - pH = (oH - 1) * sH - iH + kH - pH; - dH = dH - 1; - } - else - dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); - } - if(kW != 1) { - if(isSameMode) { - pW = (oW - 1) * sW - iW + kW - pW; - dW = dW - 1; - } - else - dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); - } - } + // if(kH != 1) { + // if(isSameMode) { + // pH = (oH - 1) * sH - iH + kH - pH; + // dH = dH - 1; + // } + // else + // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); + // } + // if(kW != 1) { + // if(isSameMode) { + // pW = (oW - 1) * sW - iW + kW - pW; + // dW = dW - 1; + // } + // else + // dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); + // } + // } - static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) { + // static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) { - if(kD != 1) { - if(isSameMode) { - pD = (oD - 1) * sD - iD + kD - pD; - dD = dD - 1; - } - else - dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1); - } - if(kH != 1) { - if(isSameMode) { - pH = (oH - 1) * sH - iH + kH - pH; - dH = dH - 1; - } - else - dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); - } - if(kW != 1) { - if(isSameMode) { - pW = (oW - 1) * sW - iW + kW - pW; - dW = dW - 1; - } - else - dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); - } - } + // if(kD != 1) { + // if(isSameMode) { + // pD = (oD - 1) * sD - iD + kD - pD; + // dD = dD - 1; + // } + // else + // dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1); + // } + // if(kH != 1) { + // if(isSameMode) { + // pH = (oH - 1) * sH - iH + kH - pH; + // dH = dH - 1; + // } + // else + // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); + // } + // if(kW != 1) { + // if(isSameMode) { + // pW = (oW - 1) * sW - iW + kW - pW; + // dW = dW - 1; + // } + // else + // dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); + // } + // } static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp index 
a36330fbe..a64864b1b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp @@ -47,9 +47,12 @@ static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output, const bool inOutAreSame = x == z; + int posOfNonUnityDim; + bias.isCommonVector(posOfNonUnityDim); + const uint bS = output.sizeAt(0); // batch size - const Nd4jLong yStrideC = bias.stridesOf()[0]; - const Nd4jLong zStrideB = output.stridesOf()[0]; + const Nd4jLong yStrideC = bias.strideAt(posOfNonUnityDim); + const Nd4jLong zStrideB = output.strideAt(0); if(output.rankOf() == 4) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp index f5c0fe71c..6e801b1fa 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp @@ -54,6 +54,7 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const const uint oW = output->sizeAt(2); auto func = PRAGMA_THREADS_FOR_2D { + for (uint b = start_x; b < stop_x; b += inc_x) { for (uint oh = start_y; oh < stop_y; oh += inc_y) { for (uint ow = 0; ow < oW; ++ow) { @@ -69,13 +70,17 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const const int iw = ow * sW - pW + kw * dW; if (iw < 0 || iw >= iW) continue; - const X val = x[shape::getOffset(xShapeInfo, {b, (uint) ih, (uint) iw, c})] + y[shape::getOffset(yShapeInfo, {kh, kw, c})]; + uint xCoords[4] = {b, (uint)ih, (uint)iw, c}; + uint yCoords[3] = {kh, kw, c}; + + const X val = x[shape::getOffset(xShapeInfo, xCoords)] + y[shape::getOffset(yShapeInfo, yCoords)]; if (val > max) max = val; } } - z[shape::getOffset(zShapeInfo, {b, oh, ow, c})] = static_cast(max); + uint zCoords[4] = {b, oh, ow, c}; + z[shape::getOffset(zShapeInfo, zCoords)] = static_cast(max); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp index 6ea2992b9..d55e4f6dd 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp @@ -47,7 +47,7 @@ namespace helpers { if (zeroPointFromMin > quantMaxF) { return static_cast(quantMax); } - return nd4j::math::nd4j_round(zeroPointFromMin); + return (uint16_t)nd4j::math::nd4j_round(zeroPointFromMin); }(); // compute nudged min and max with computed nudged zero point *nudgedMin = (quantMinF - nudged_zero_point) * (*scale); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp index 05c5eaf6e..d9b018268 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp @@ -13,16 +13,52 @@ * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
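// Context for the dilation2d hunk above: grayscale (morphological) dilation takes
// a max over input-plus-kernel values. A minimal scalar sketch of the per-pixel
// inner loop, with illustrative names (strides sH/sW, paddings pH/pW and dilations
// dH/dW as in the helper):
static float dilatePixelSketch(const float* in, int iH, int iW,
                               const float* kernel, int kH, int kW,
                               int oh, int ow, int sH, int sW,
                               int pH, int pW, int dH, int dW) {
    float max = -1e38f;                              // effectively -infinity
    for (int kh = 0; kh < kH; ++kh) {
        const int ih = oh * sH - pH + kh * dH;
        if (ih < 0 || ih >= iH) continue;
        for (int kw = 0; kw < kW; ++kw) {
            const int iw = ow * sW - pW + kw * dW;
            if (iw < 0 || iw >= iW) continue;
            const float val = in[ih * iW + iw] + kernel[kh * kW + kw];
            if (val > max) max = val;
        }
    }
    return max;
}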
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ // // @author sgazeos@gmail.com // #include #include +#include namespace nd4j { namespace ops { namespace helpers { + typedef std::vector> ColorTable_t; + static ColorTable_t DefaultColorTable(int depth) { + std::vector> colorTable; + colorTable.emplace_back(std::vector({1, 1, 0, 1})); // 0: yellow + colorTable.emplace_back(std::vector({0, 0, 1, 1})); // 1: blue + colorTable.emplace_back(std::vector({1, 0, 0, 1})); // 2: red + colorTable.emplace_back(std::vector({0, 1, 0, 1})); // 3: lime + colorTable.emplace_back(std::vector({0.5, 0, 0.5, 1})); // 4: purple + colorTable.emplace_back(std::vector({0.5, 0.5, 0, 1})); // 5: olive + colorTable.emplace_back(std::vector({0.5, 0, 0, 1})); // 6: maroon + colorTable.emplace_back(std::vector({0, 0, 0.5, 1})); // 7: navy blue + colorTable.emplace_back(std::vector({0, 1, 1, 1})); // 8: aqua + colorTable.emplace_back(std::vector({1, 0, 1, 1})); // 9: fuchsia + + if (depth == 1) { + for (Nd4jLong i = 0; i < colorTable.size(); i++) { + colorTable[i][0] = 1; + } + } + return colorTable; + } void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) { // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set @@ -33,44 +69,88 @@ namespace helpers { // colors - colors for each box given // set up color for each box as frame auto batchSize = images->sizeAt(0); + auto boxSize = boxes->sizeAt(0); auto height = images->sizeAt(1); auto width = images->sizeAt(2); auto channels = images->sizeAt(3); //auto imageList = images->allTensorsAlongDimension({1, 2, 3}); // split images by batch // auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split boxes by batch - auto colorSet = colors->allTensorsAlongDimension({1}); + //auto colorSet = colors->allTensorsAlongDimension({0}); output->assign(images); // fill up all output with input images, then fill up boxes - - PRAGMA_OMP_PARALLEL_FOR - for (auto b = 0; b < batchSize; ++b) { // loop by batch -// auto image = imageList->at(b); - - for (auto c = 0; c < colorSet->size(); ++c) { - // box with shape - auto internalBox = (*boxes)(b, {0})(c, {0});//internalBoxes->at(c); - auto color = colorSet->at(c); - auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox.e(0))); - auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox.e(2))); - auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox.e(1))); - auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox.e(3))); - for (auto y = rowStart; y <= rowEnd; y++) { - for (auto e = 0; e < color->lengthOf(); ++e) { - output->p(b, y, colStart, e, color->e(e)); - output->p(b, y, colEnd, e, color->e(e)); - } + ColorTable_t colorTable; + if (colors) { + for (auto i = 0; i < colors->sizeAt(0); i++) { + std::vector colorValue(4); + for (auto j = 0; j < 4; j++) { + colorValue[j] = j < colors->sizeAt(1) ? 
colors->e(i, j) : 1.f; } - for (auto x = colStart + 1; x < colEnd; x++) { - for (auto e = 0; e < color->lengthOf(); ++e) { - output->p(b, rowStart, x, e, color->e(e)); - output->p(b, rowEnd, x, e, color->e(e)); + colorTable.emplace_back(colorValue); + } + } + if (colorTable.empty()) + colorTable = DefaultColorTable(channels); + auto func = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; ++batch) { // loop by batch + const Nd4jLong numBoxes = boxes->sizeAt(1); + for (auto boxIndex = 0; boxIndex < numBoxes; ++boxIndex) { + auto colorIndex = boxIndex % colorTable.size(); + auto rowStart = Nd4jLong((height - 1) * boxes->t(batch, boxIndex, 0)); + auto rowStartBound = nd4j::math::nd4j_max(Nd4jLong(0), rowStart); + auto rowEnd = Nd4jLong((height - 1) * boxes->t(batch, boxIndex, 2)); + auto rowEndBound = nd4j::math::nd4j_min(Nd4jLong(height - 1), rowEnd); + auto colStart = Nd4jLong((width - 1) * boxes->t(batch, boxIndex, 1)); + auto colStartBound = nd4j::math::nd4j_max(Nd4jLong(0), colStart); + auto colEnd = Nd4jLong((width - 1) * boxes->t(batch, boxIndex, 3)); + auto colEndBound = nd4j::math::nd4j_min(Nd4jLong(width - 1), colEnd); + + if (rowStart > rowEnd || colStart > colEnd) { + nd4j_debug( + "helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is inverted " + "and will not be drawn\n", rowStart, colStart, rowEnd, colEnd); + continue; + } + if (rowStart >= height || rowEnd < 0 || colStart >= width || + colEnd < 0) { + nd4j_debug( + "helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is completely " + "outside the image and not be drawn\n ", rowStart, colStart, rowEnd, colEnd); + continue; + } + + // Draw upper line + if (rowStart >= 0) { + for (auto j = colStartBound; j <= colEndBound; ++j) + for (auto c = 0; c < channels; c++) { + output->p(batch, rowStart, j, c, colorTable[colorIndex][c]); + } + } + // Draw bottom line. + if (rowEnd < height) { + for (auto j = colStartBound; j <= colEndBound; ++j) + for (auto c = 0; c < channels; c++) { + output->p(batch, rowEnd, j, c, colorTable[colorIndex][c]); + } + } + + // Draw left line. + if (colStart >= 0) { + for (auto i = rowStartBound; i <= rowEndBound; ++i) + for (auto c = 0; c < channels; c++) { + output->p(batch, i, colStart, c, colorTable[colorIndex][c]); + } + } + // Draw right line. + if (colEnd < width) { + for (auto i = rowStartBound; i <= rowEndBound; ++i) + for (auto c = 0; c < channels; c++) { + output->p(batch, i, colEnd, c, colorTable[colorIndex][c]); + } } } } -// delete internalBoxes; - } - delete colorSet; -// delete imageList; -// delete boxList; + }; + samediff::Threads::parallel_tad(func, 0, batchSize); + } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index 11bc1ecaa..0f69ef0fb 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -13,6 +14,20 @@ * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
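// The drawing helper above maps normalized box coordinates (yMin, xMin, yMax, xMax)
// to pixel rows/columns and clamps them to the image bounds; the same mapping as a
// compact sketch (hypothetical struct, identical formulas):
struct PixelBoxSketch { Nd4jLong rowStart, rowEnd, colStart, colEnd; };

static PixelBoxSketch boxToPixels(const float box[4], Nd4jLong height, Nd4jLong width) {
    PixelBoxSketch p;
    p.rowStart = nd4j::math::nd4j_max(Nd4jLong(0),          Nd4jLong((height - 1) * box[0]));
    p.rowEnd   = nd4j::math::nd4j_min(Nd4jLong(height - 1), Nd4jLong((height - 1) * box[2]));
    p.colStart = nd4j::math::nd4j_max(Nd4jLong(0),          Nd4jLong((width  - 1) * box[1]));
    p.colEnd   = nd4j::math::nd4j_min(Nd4jLong(width  - 1), Nd4jLong((width  - 1) * box[3]));
    return p;   // boxes with rowStart > rowEnd or colStart > colEnd are skipped as inverted
}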
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ // // @author sgazeos@gmail.com @@ -32,6 +47,91 @@ namespace helpers { // https://en.wikipedia.org/wiki/Bilinear_interpolation) double interpolarValue; }; + // calculateResizeScale determines the float scaling factor. + inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, + bool alignCorners) { + return (alignCorners && outSize > 1) + ? (inSize - 1) / static_cast(outSize - 1) + : inSize / static_cast(outSize); + } + + struct ImageResizerState { + explicit ImageResizerState(bool alignCorners, bool halfPixelCenters) + : _alignCorners(alignCorners), + _halfPixelCenters(halfPixelCenters) {} + + // ValidateAndCalculateOutputSize checks the bounds on the input tensors + // and requested size, sets up some of the resizing state such as the + // heightScale and widthScale, and calculates the output size. + // If any of these operations fails, it sets an error status in + // the context, which the caller must check. + int validateAndCalculateOutputSize(NDArray const* input, int const width, int const height) { + // + batchSize = input->sizeAt(0);//.dim_size(0); + outHeight = height; + outWidth = width; //internal::SubtleMustCopy(Svec(1)); + inHeight = static_cast(input->sizeAt(1)); + inWidth = static_cast(input->sizeAt(2)); + channels = input->sizeAt(3); //.dim_size(3); + heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners); + widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners); + + // Guard against overflows + if (ceilf((outHeight - 1) * heightScale) > static_cast(DataTypeUtils::max())) { + nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); + return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize height"); + } + if (ceilf((outWidth - 1) * heightScale) > static_cast(DataTypeUtils::max())) { + nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); + return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize width"); + } + + return Status::OK(); + } + + // Calculates all the required variables, and allocates the output. + int validateAndCreateOutput(NDArray const* input, int const width, int const height) { + return validateAndCalculateOutputSize(input, width, height); + } + + Nd4jLong batchSize; + Nd4jLong outHeight; + Nd4jLong outWidth; + Nd4jLong inHeight; + Nd4jLong inWidth; + Nd4jLong channels; + float heightScale; + float widthScale; + NDArray* output = nullptr; + + private: + bool _alignCorners; + bool _halfPixelCenters; + }; + + // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the +// floating point coordinates of the top,left pixel is 0.5,0.5. 
+ struct HalfPixelScaler { + HalfPixelScaler(){}; + inline float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale - 0.5f; + } + }; + + struct WeightsAndIndices { + float _weight0; + float _weight1; + float _weight2; + float _weight3; + Nd4jLong _index0; + Nd4jLong _index1; + Nd4jLong _index2; + Nd4jLong _index3; + + int _advance; // advance value. + }; inline void computeInterpolationWeights(Nd4jLong outSize, Nd4jLong inSize, @@ -40,13 +140,16 @@ namespace helpers { interpolationData[outSize].bottomIndex = 0; interpolationData[outSize].topIndex = 0; - PRAGMA_OMP_PARALLEL_FOR - for (Nd4jLong i = outSize - 1; i >= 0; --i) { - double in = i * scale; - interpolationData[i].bottomIndex = static_cast(in); - interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1); - interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex; - } + auto func = PRAGMA_THREADS_FOR { + for (auto k = start; k < stop; k++) { + auto i = (outSize - k - 1); + double in = i * scale; + interpolationData[i].bottomIndex = static_cast(in); + interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1); + interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex; + } + }; + samediff::Threads::parallel_for(func, 0, outSize); } /** @@ -72,10 +175,10 @@ namespace helpers { Nd4jLong inBatchNumValues = inHeight * inRowSize; Nd4jLong outRowSize = outWidth * channels; - T const *input_b_ptr = reinterpret_cast(images->getBuffer()); // this works only with 'c' direction + T const *pInput = images->getDataBuffer()->primaryAsT(); // this works only with 'c' direction BilinearInterpolationData const *xs_ = xs.data(); - T *output_y_ptr = reinterpret_cast(output->buffer()); + T* pOutput = output->dataBuffer()->primaryAsT(); auto computeBilinear = [](double topLeft, double topRight, double bottomLeft, double bottomRight, double xVal, double yVal) { @@ -84,30 +187,32 @@ namespace helpers { return top + (bottom - top) * yVal; }; - // FIXME: fix parallelism here - for (Nd4jLong b = 0; b < batchSize; ++b) { - for (Nd4jLong y = 0; y < outHeight; ++y) { - const T *ys_input_lower_ptr = input_b_ptr + ys[y].bottomIndex * inRowSize; - const T *ys_input_upper_ptr = input_b_ptr + ys[y].topIndex * inRowSize; + auto func = PRAGMA_THREADS_FOR { + for (auto b = start; b < stop; ++b) { + for (auto y = 0; y < outHeight; ++y) { + const T *ys_input_lower_ptr = pInput + ys[y].bottomIndex * inRowSize; + const T *ys_input_upper_ptr = pInput + ys[y].topIndex * inRowSize; double yVal = ys[y].interpolarValue; - for (Nd4jLong x = 0; x < outWidth; ++x) { + for (auto x = 0; x < outWidth; ++x) { auto xsBottom = xs_[x].bottomIndex; auto xsTop = xs_[x].topIndex; auto xVal = xs_[x].interpolarValue; - for (int c = 0; c < channels; ++c) { + for (auto c = 0; c < channels; ++c) { double topLeft(ys_input_lower_ptr[xsBottom + c]); double topRight(ys_input_lower_ptr[xsTop + c]); double bottomLeft(ys_input_upper_ptr[xsBottom + c]); double bottomRight(ys_input_upper_ptr[xsTop + c]); - output_y_ptr[x * channels + c] = + pOutput[x * channels + c] = computeBilinear(topLeft, topRight, bottomLeft, bottomRight, xVal, yVal); } } - output_y_ptr += outRowSize; + pOutput += outRowSize; } - input_b_ptr += inBatchNumValues; + pInput += inBatchNumValues; } + }; + 
samediff::Threads::parallel_tad(func, 0, batchSize); } template @@ -175,7 +280,7 @@ namespace helpers { // Handle no-op resizes efficiently. if (outHeight == inHeight && outWidth == inWidth) { output->assign(images); - return ND4J_STATUS_OK; + return Status::OK(); } if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || @@ -192,9 +297,9 @@ namespace helpers { for (auto y = start_y; y < stop_y; y += inc_y) { Nd4jLong inY = nd4j::math::nd4j_min((center) ? static_cast(nd4j::math::p_round(y * heightScale)) : static_cast(nd4j::math::p_floor(y * heightScale)), inHeight - 1); - for (int x = 0; x < outWidth; ++x) { + for (auto x = 0; x < outWidth; ++x) { Nd4jLong inX = nd4j::math::nd4j_min((center) ? static_cast(nd4j::math::p_round(x * widthScale)) : static_cast(nd4j::math::p_floor(x * widthScale)),inWidth - 1); - for (Nd4jLong e = 0; e < channels; e++) + for (auto e = 0; e < channels; e++) output->p(b, y, x, e, images->e(b, inY, inX, e)); } } @@ -215,28 +320,16 @@ namespace helpers { LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void resizeImage_, - (NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - std::vector const& xs, - std::vector const& ys, - NDArray* output), LIBND4J_TYPES); - int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) { BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_, (images, width, height, center, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, - (NDArray const* images, int width, int height, bool center, NDArray* output), LIBND4J_TYPES); - int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) { BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, (images, width, height, center, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, - (NDArray const* images, int width, int height, bool center, NDArray* output), LIBND4J_TYPES); template static void cropAndResizeFunctor_(NDArray const *images, NDArray const *boxes, NDArray const *indices, @@ -250,7 +343,7 @@ namespace helpers { const int cropWidth = crops->sizeAt(2); const int depth = crops->sizeAt(3); - for (int b = 0; b < numBoxes; ++b) { + for (auto b = 0; b < numBoxes; ++b) { T y1 = boxes->t(b, 0); T x1 = boxes->t(b, 1); T y2 = boxes->t(b, 2); @@ -265,14 +358,14 @@ namespace helpers { T widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : T(0); auto func = PRAGMA_THREADS_FOR { - for (int y = start; y < stop; y += increment) { + for (auto y = start; y < stop; y += increment) { const float inY = (cropHeight > 1) ? y1 * (imageHeight - 1) + y * heightScale : 0.5 * (y1 + y2) * (imageHeight - 1); if (inY < 0 || inY > imageHeight - 1) { - for (int x = 0; x < cropWidth; ++x) { - for (int d = 0; d < depth; ++d) { + for (auto x = 0; x < cropWidth; ++x) { + for (auto d = 0; d < depth; ++d) { crops->p(b, y, x, d, extrapolationVal); } } @@ -283,13 +376,13 @@ namespace helpers { const int bottomYIndex = nd4j::math::p_ceil(inY); const float y_lerp = inY - topYIndex; - for (int x = 0; x < cropWidth; ++x) { + for (auto x = 0; x < cropWidth; ++x) { const float in_x = (cropWidth > 1) ? 
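The nearest-neighbour path above maps each output coordinate to a single input index: rounding when 'center' (align corners) is requested, flooring otherwise, clamped into range. A small sketch of that index mapping:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    using Nd4jLong = std::int64_t;

    // Source index for nearest-neighbour resize, as in resizeNeighborFunctor_.
    inline Nd4jLong nearestIndex(Nd4jLong out, float scale, Nd4jLong inSize, bool center) {
        float pos = out * scale;
        Nd4jLong idx = center ? static_cast<Nd4jLong>(std::round(pos))
                              : static_cast<Nd4jLong>(std::floor(pos));
        return std::min(idx, inSize - 1);
    }

    int main() {
        // Downscale 8 -> 3: scale = 8/3.
        for (Nd4jLong x = 0; x < 3; ++x)
            printf("out %lld <- in %lld\n", (long long)x,
                   (long long)nearestIndex(x, 8.0f / 3.0f, 8, false)); // 0, 2, 5
        return 0;
    }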
x1 * (imageWidth - 1) + x * widthScale : 0.5 * (x1 + x2) * (imageWidth - 1); if (in_x < 0 || in_x > imageWidth - 1) { - for (int d = 0; d < depth; ++d) { + for (auto d = 0; d < depth; ++d) { crops->p(b, y, x, d, extrapolationVal); } continue; @@ -298,7 +391,7 @@ namespace helpers { int right_x_index = math::p_ceil(in_x); T x_lerp = in_x - left_x_index; - for (int d = 0; d < depth; ++d) { + for (auto d = 0; d < depth; ++d) { const float topLeft(images->e(bIn, topYIndex, left_x_index, d)); const float topRight(images->e(bIn, topYIndex, right_x_index, d)); const float bottomLeft(images->e(bIn, bottomYIndex, left_x_index, d)); @@ -309,21 +402,21 @@ namespace helpers { } } } else { // method is "nearest neighbor" - for (int x = 0; x < cropWidth; ++x) { + for (auto x = 0; x < cropWidth; ++x) { const float inX = (cropWidth > 1) ? x1 * (imageWidth - 1) + x * widthScale : 0.5 * (x1 + x2) * (imageWidth - 1); if (inX < 0 || inX > imageWidth - 1) { - for (int d = 0; d < depth; ++d) { + for (auto d = 0; d < depth; ++d) { crops->p(b, y, x, d, extrapolationVal); } continue; } const int closestXIndex = roundf(inX); const int closestYIndex = roundf(inY); - for (int d = 0; d < depth; ++d) { - crops->p(b, y, x, d, (F) images->e(bIn, closestYIndex, closestXIndex, d)); + for (auto d = 0; d < depth; ++d) { + crops->p(b, y, x, d, images->e(bIn, closestYIndex, closestXIndex, d)); } } } @@ -333,10 +426,473 @@ namespace helpers { samediff::Threads::parallel_for(func, 0, cropHeight); } } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// ------------------------------------------------------------------------------------------------------------------ // +// Bicubic interpolation +// ------------------------------------------------------------------------------------------------------------------ // + class CachedInterpolationCalculator { + public: + CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {} + // Advances iteration. Returns the number of values that should be copied from + // the current point to the next point. The copying should always be done by + // copying the last values from the old point to the first + // values of the new point. + inline int Advance(const Nd4jLong x0, const Nd4jLong x1, const Nd4jLong x2, + const Nd4jLong x3) { + // We use 2 hands and walk through, copying from one to another where + // we already have values. + // Invariant, new_indicies_hand <= cached_values_hand + const Nd4jLong new_x_indices[] = {x0, x1, x2, x3}; + int cachedValuesHand = 0; + int newIndiciesHand = 0; + while (cachedValuesHand < 4) { + if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) { + if (newIndiciesHand < cachedValuesHand) { + _indexes[newIndiciesHand] = _indexes[cachedValuesHand]; + } + newIndiciesHand++; + } + cachedValuesHand++; + } + switch (newIndiciesHand) { + case 0: + _indexes[0] = x0; + case 1: + _indexes[1] = x1; + case 2: + _indexes[2] = x2; + case 3: + _indexes[3] = x3; + break; + } + return newIndiciesHand; + } + private: + Nd4jLong _indexes[4]; + }; + static const Nd4jLong kTableSize = 1024LL; //(1 << 10); + + const float* initCoeffsTable(const double a) { + // Allocate and initialize coefficients table using Bicubic + // convolution algorithm. 
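CachedInterpolationCalculator::Advance above is the heart of the bicubic caching: it reports how many of the new four x-indices were already cached so the caller recomputes only the remainder, and its trailing switch deliberately omits break so that every slot from the returned hand onward is refreshed. A standalone sketch with a usage trace:

    #include <cstdint>
    #include <cstdio>

    using Nd4jLong = std::int64_t;

    // Tracks the four x-indices used for the previous output pixel so that
    // y-interpolated values can be reused when consecutive pixels overlap.
    class CachedInterpolationCalculator {
    public:
        CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {}

        // Returns how many of the new indices were already cached; the caller
        // shifts that many values forward and recomputes only the remainder.
        int Advance(Nd4jLong x0, Nd4jLong x1, Nd4jLong x2, Nd4jLong x3) {
            const Nd4jLong newX[4] = {x0, x1, x2, x3};
            int cachedHand = 0, newHand = 0;
            while (cachedHand < 4) {
                if (_indexes[cachedHand] == newX[newHand]) {
                    if (newHand < cachedHand) _indexes[newHand] = _indexes[cachedHand];
                    ++newHand;
                }
                ++cachedHand;
            }
            // Deliberate fall-through: everything from newHand on is refreshed.
            switch (newHand) {
                case 0: _indexes[0] = x0;
                case 1: _indexes[1] = x1;
                case 2: _indexes[2] = x2;
                case 3: _indexes[3] = x3; break;
            }
            return newHand;
        }

    private:
        Nd4jLong _indexes[4];
    };

    int main() {
        CachedInterpolationCalculator calc;
        printf("%d\n", calc.Advance(0, 1, 2, 3)); // 0: nothing cached yet
        printf("%d\n", calc.Advance(1, 2, 3, 4)); // 3: three taps reusable
        printf("%d\n", calc.Advance(1, 2, 3, 4)); // 4: identical window
        return 0;
    }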
+ // https://en.wikipedia.org/wiki/Bicubic_interpolation + float* coeffs_table = new float[(kTableSize + 1) * 2]; + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i <= stop; ++i) { + float x = i * 1.0 / kTableSize; + coeffs_table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1; + x += 1.0; + coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; + } + }; + samediff::Threads::parallel_for(func, 0, kTableSize); + return coeffs_table; + } + + const float* getCoeffsTable(const bool use_keys_cubic) { + // Static so that we initialize it on first use + if (use_keys_cubic) { + // http://ieeexplore.ieee.org/document/1163711/ + // R. G. Keys. Cubic convolution interpolation for digital image + // processing. IEEE Transactions on Acoustics, Speech, and Signal + // Processing, 29(6):1153–1160, 1981. + static const float* coeffs_table = initCoeffsTable(-0.5f); + return coeffs_table; + } else { + static const float* coeffs_table = initCoeffsTable(-0.75f); + return coeffs_table; + } + } + + inline Nd4jLong bound(Nd4jLong val, Nd4jLong limit) { + return math::nd4j_min(limit - 1ll, math::nd4j_max(Nd4jLong{0}, val)); + } + + template + int resizeBicubicFunctor_(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + bool preserveAspectRatio, bool antialias, NDArray* output) { + return ND4J_STATUS_OK; + } + + int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + bool preserveAspectRatio, bool antialias, NDArray* output) { + BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctor_, (context, image, + width, height, preserveAspectRatio, antialias, output), NUMERIC_TYPES); + } +// ------------------------------------------------------------------------------------------------------------------ // + + template + inline float interpolate1D(const float weight0, const float weight1, const float weight2, const float weight3, + const T value0, const T value1, const T value2, const T value3) { + return static_cast(value0) * weight0 + + static_cast(value1) * weight1 + + static_cast(value2) * weight2 + + static_cast(value3) * weight3; + } + +// Compute the 1D interpolation for a given X index using the y_weights + static float compute(float values[4], const float xW0, const float xW1, const float xW2, const float xW3) { + return interpolate1D(xW0, xW1, xW2, xW3, values[0], values[1],values[2], values[3]); + } + + template + inline void getWeightsAndIndices(const float scale, const Nd4jLong out_loc, const Nd4jLong limit, WeightsAndIndices* out) { + const Scaler scaler; + const float in_loc_f = scaler(out_loc, scale); + const Nd4jLong in_loc = std::floor(in_loc_f); + const float delta = in_loc_f - in_loc; + const Nd4jLong offset = lrintf(delta * kTableSize); + const float* coeffs_table = getCoeffsTable(use_keys_cubic); + if (use_keys_cubic) { + // The legacy code placed more weight on the edge pixels, since bounding + // the set of inputs to sample could cause an edge pixel to be repeated. + // Here we change the behavior at borders to match that used by the + // scale_and_translate_op, where sampling locations outside the image have + // their weight set to 0, and the weights are renormalized so that their sum + // is 1.0. + out->_index0 = bound(in_loc - 1, limit); + out->_weight0 = + (out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f); + out->_index1 = bound(in_loc, limit); + out->_weight1 = (out->_index1 == in_loc ? 
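The two polynomials in initCoeffsTable are the two branches of the Keys cubic convolution kernel: the near branch for |x| <= 1 (stored at i*2) and the far branch for 1 < |x| < 2 (stored at i*2+1). A host-side sketch evaluating the kernel directly, showing that the four unit-spaced taps sum to 1, which is why the table entries can be used as interpolation weights:

    #include <cmath>
    #include <cstdio>

    // Keys cubic convolution kernel; a = -0.5 is the Catmull-Rom/Keys case,
    // a = -0.75 is the legacy default used when use_keys_cubic is false.
    float keysKernel(float x, float a) {
        x = std::fabs(x);
        if (x <= 1.0f) return ((a + 2) * x - (a + 3)) * x * x + 1;
        if (x < 2.0f)  return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
        return 0.0f;
    }

    int main() {
        const float a = -0.5f;
        // For a sample at fractional offset t past the second tap, the four
        // taps sit at distances 1+t, t, 1-t, 2-t, matching the table lookups
        // at offset*2+1, offset*2, (size-offset)*2, (size-offset)*2+1.
        float t = 0.25f;
        float w[4] = {keysKernel(1 + t, a), keysKernel(t, a),
                      keysKernel(1 - t, a), keysKernel(2 - t, a)};
        printf("weights: %f %f %f %f (sum %f)\n",
               w[0], w[1], w[2], w[3], w[0] + w[1] + w[2] + w[3]); // sum 1.000000
        return 0;
    }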
coeffs_table[offset * 2] : 0.0f); + out->_index2 = bound(in_loc + 1, limit); + out->_weight2 = + (out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] + : 0.0f); + out->_index3 = bound(in_loc + 2, limit); + out->_weight3 = (out->_index3 == in_loc + 2 + ? coeffs_table[(kTableSize - offset) * 2 + 1] + : 0.0f); + + const float weight_sum = + out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3; + if (std::abs(weight_sum) >= 1000.0f * std::numeric_limits::min()) { + const float one_over_weight_sum = 1.0f / weight_sum; + out->_weight0 *= one_over_weight_sum; + out->_weight1 *= one_over_weight_sum; + out->_weight2 *= one_over_weight_sum; + out->_weight3 *= one_over_weight_sum; + } + } else { + out->_weight0 = coeffs_table[offset * 2 + 1]; + out->_weight1 = coeffs_table[offset * 2]; + out->_weight2 = coeffs_table[(kTableSize - offset) * 2]; + out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1]; + out->_index0 = bound(in_loc - 1, limit); + out->_index1 = bound(in_loc, limit); + out->_index2 = bound(in_loc + 1, limit); + out->_index3 = bound(in_loc + 2, limit); + } + } + +// Older incorrect scaling method that causes all resizes to have a slight +// translation leading to inconsistent results. For example, a flip then a +// resize gives different results then a resize then a flip. + struct LegacyScaler { + LegacyScaler(){}; + inline float operator()(const int x, const float scale) const { + return static_cast(x) * scale; + } + }; + + static void computeXWeightsAndIndices(const ImageResizerState& resizer_state, + const bool half_pixel_centers, + std::vector* x_wais) { + CachedInterpolationCalculator calc; + if (half_pixel_centers) { + auto func = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; ++x) { + getWeightsAndIndices( + resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]); + auto &x_wai = (*x_wais)[x]; + x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, x_wai._index2, + x_wai._index3); + } + }; + samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; ++x) { + getWeightsAndIndices( + resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]); + auto& x_wai = (*x_wais)[x]; + x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, x_wai._index2, + x_wai._index3); + } + }; + samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); + } + // Scale the values so they can be used as offsets into buffers. 
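The LegacyScaler/HalfPixelScaler pair above is the only difference between the old and new coordinate mappings: the half-pixel variant treats pixel centres as sitting at +0.5, which is what makes a resize commute with a flip. A tiny comparison sketch:

    #include <cstdio>

    // Legacy mapping: out * scale. Half-pixel mapping centres each output
    // pixel over its source neighbourhood.
    struct LegacyScaler {
        float operator()(int x, float scale) const { return x * scale; }
    };
    struct HalfPixelScaler {
        float operator()(int x, float scale) const { return (x + 0.5f) * scale - 0.5f; }
    };

    int main() {
        LegacyScaler legacy; HalfPixelScaler half;
        // Downscale by 2: the legacy mapping anchors at the top-left sample,
        // the half-pixel mapping is shifted by half an input pixel.
        for (int x = 0; x < 4; ++x)
            printf("out %d -> legacy %.2f, half-pixel %.2f\n",
                   x, legacy(x, 2.0f), half(x, 2.0f));
        return 0;
    }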
+ auto func = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; ++x) { + (*x_wais)[x]._index0 *= resizer_state.channels; + (*x_wais)[x]._index1 *= resizer_state.channels; + (*x_wais)[x]._index2 *= resizer_state.channels; + (*x_wais)[x]._index3 *= resizer_state.channels; + } + }; + samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); + } + + template + static FORCEINLINE float computeYInterpolation( + int which, int channelNum, const WeightsAndIndices& yWai, + const T* pY0, const T* pY1, const T* pY2, const T* pY3, + const WeightsAndIndices& xWai) { + int xIndex; + switch (which) { + case 0: + xIndex = xWai._index0; + break; + case 1: + xIndex = xWai._index1; + break; + case 2: + xIndex = xWai._index2; + break; + default: + xIndex = xWai._index3; + break; + } + const Nd4jLong pt_index = xIndex + channelNum; + return interpolate1D(yWai._weight0, yWai._weight1, yWai._weight2, + yWai._weight3, pY0[pt_index], pY1[pt_index], + pY2[pt_index], pY3[pt_index]); + } + + template + static void + bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) { + std::vector xWais(resizerState.outWidth); + + computeXWeightsAndIndices(resizerState, halfPixelCenters, &xWais); + + const auto numChannels = resizerState.channels; + const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; + const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; + + const T* inputPtr = image->getDataBuffer()->primaryAsT(); + T* pOutputY = output->dataBuffer()->primaryAsT(); //_data.data(); + std::vector cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0); + + auto func = PRAGMA_THREADS_FOR { + for (auto b = start; b < stop; ++b) { + auto pInput = inputPtr + b * inBatchWidth; + + for (auto y = 0; y < resizerState.outHeight; ++y) { + auto pOutput = &pOutputY[(b * resizerState.outHeight + y) * resizerState.outWidth * numChannels]; + + WeightsAndIndices yWai; + if (halfPixelCenters) { + getWeightsAndIndices( + resizerState.heightScale, y, resizerState.inHeight, &yWai); + } else { + getWeightsAndIndices( + resizerState.heightScale, y, resizerState.inHeight, &yWai); + } + // Make pointers represent offsets of data in inputBPtr. + const T *y_ptr_0 = pInput + yWai._index0 * inRowWidth; + const T *y_ptr_1 = pInput + yWai._index1 * inRowWidth; + const T *y_ptr_2 = pInput + yWai._index2 * inRowWidth; + const T *y_ptr_3 = pInput + yWai._index3 * inRowWidth; + + if (numChannels == 3) { + // Manually unroll case of 3 channels. + float cached_value_0[4] = {0}; + float cached_value_1[4] = {0}; + float cached_value_2[4] = {0}; + for (auto x = 0; x < resizerState.outWidth; ++x) { + const WeightsAndIndices &xWai = xWais[x]; + // Shift values in cached_value_* to fill first '_advance' values. 
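The final loop above pre-multiplies each _index by channels so that, inside the hot loop, xIndex + channelNum is already a flat element offset into the NHWC buffer. A sketch of that address arithmetic (nhwcOffset is a hypothetical helper name):

    #include <cstdint>
    #include <cstdio>

    using Nd4jLong = std::int64_t;

    // In an NHWC buffer a row holds inWidth*channels values; once the x-index
    // is pre-multiplied by channels, adding the channel number gives the
    // final element offset, as computeYInterpolation's pt_index relies on.
    inline Nd4jLong nhwcOffset(Nd4jLong b, Nd4jLong y, Nd4jLong xTimesChannels, Nd4jLong c,
                               Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong channels) {
        Nd4jLong inRowWidth = inWidth * channels;
        Nd4jLong inBatchWidth = inHeight * inRowWidth;
        return b * inBatchWidth + y * inRowWidth + xTimesChannels + c;
    }

    int main() {
        // Element (b=1, y=2, x=3, c=0) of a 4x4 RGB image: 48 + 24 + 9 = 81.
        printf("%lld\n", (long long)nhwcOffset(1, 2, 3 * 3, 0, 4, 4, 3));
        return 0;
    }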
+ switch (xWai._advance) { + case 3: + cached_value_0[0] = cached_value_0[1]; + cached_value_0[1] = cached_value_0[2]; + cached_value_0[2] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[1]; + cached_value_1[1] = cached_value_1[2]; + cached_value_1[2] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[1]; + cached_value_2[1] = cached_value_2[2]; + cached_value_2[2] = cached_value_2[3]; + break; + case 2: + cached_value_0[0] = cached_value_0[2]; + cached_value_0[1] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[2]; + cached_value_1[1] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[2]; + cached_value_2[1] = cached_value_2[3]; + break; + case 1: { + cached_value_0[0] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[3]; + break; + } + } + + // Set the remaining '4-_advance' values by computing. + switch (xWai._advance) { + case 0: + cached_value_0[0] = computeYInterpolation( + 0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[0] = computeYInterpolation( + 0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[0] = computeYInterpolation( + 0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + + case 1: + cached_value_0[1] = computeYInterpolation( + 1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[1] = computeYInterpolation( + 1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[1] = computeYInterpolation( + 1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + + case 2: + cached_value_0[2] = computeYInterpolation( + 2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[2] = computeYInterpolation( + 2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[2] = computeYInterpolation( + 2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + + case 3: + cached_value_0[3] = computeYInterpolation( + 3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[3] = computeYInterpolation( + 3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[3] = computeYInterpolation( + 3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + break; + } + pOutput[x * numChannels + 0] = + compute(cached_value_0, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + pOutput[x * numChannels + 1] = + compute(cached_value_1, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + pOutput[x * numChannels + 2] = + compute(cached_value_2, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + } + } else { + for (auto x = 0; x < resizerState.outWidth; ++x) { + const WeightsAndIndices &xWai = xWais[x]; + // Shift values in cachedValue to fill first '_advance' values. + switch (xWai._advance) { + case 3: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 1]; + cachedValue[4 * c + 1] = cachedValue[4 * c + 2]; + cachedValue[4 * c + 2] = cachedValue[4 * c + 3]; + } + break; + case 2: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 2]; + cachedValue[4 * c + 1] = cachedValue[4 * c + 3]; + } + break; + case 1: { + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 3]; + } + break; + } + } + + // Set the remaining '4-_advance' values by computing. 
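The two switches per pixel above implement a sliding four-tap cache: the first shifts the '_advance' surviving values forward, the second starts at case '_advance' and deliberately falls through so every later slot is recomputed. A compact sketch of the refill pattern (fresh() is a stand-in for computeYInterpolation):

    #include <cstdio>

    // Recompute cache slots from 'advance' onward; the case labels fall
    // through on purpose, exactly as in the interpolation loop above.
    void refill(float cache[4], int advance, int xBase) {
        auto fresh = [xBase](int k) { return float(xBase + k); };
        switch (advance) {
            case 0: cache[0] = fresh(0); // fall through
            case 1: cache[1] = fresh(1); // fall through
            case 2: cache[2] = fresh(2); // fall through
            case 3: cache[3] = fresh(3); break;
        }
    }

    int main() {
        float cache[4] = {10, 11, 12, 13};
        // advance == 2: keep two shifted values, recompute slots 2 and 3.
        cache[0] = cache[2]; cache[1] = cache[3];
        refill(cache, 2, 100);
        printf("%g %g %g %g\n", cache[0], cache[1], cache[2], cache[3]); // 12 13 102 103
        return 0;
    }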
+ switch (xWai._advance) { + case 0: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = computeYInterpolation( + 0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + + case 1: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 1] = computeYInterpolation( + 1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + + case 2: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 2] = computeYInterpolation( + 2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + + case 3: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 3] = computeYInterpolation( + 3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + break; + } + for (auto c = 0; c < numChannels; ++c) { + pOutput[x * numChannels + c] = + compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + } + } + } + } + } + }; + samediff::Threads::parallel_tad(func, 0, resizerState.batchSize); + } + +// simplified bicubic resize without antialiasing +// + template + int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + bool const alignCorners, bool const halfPixelAlign, NDArray* output) { + ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align + int res = st.validateAndCreateOutput(image, width, height); + if (res == Status::OK()) + bicubicInterpolateWithCaching(image, st, halfPixelAlign, output); + + return res; + } + int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + bool const alignCorners, bool const halfPixelAlign, NDArray* output) { + BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, + image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES); + } +// ------------------------------------------------------------------------------------------------------------------ // + int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) { + switch (method) { + case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break; + case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, output); break; + case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break; + case kResizeLanczos5: + case kResizeGaussian: + case kResizeArea: + case kResizeMitchelcubic: + throw std::runtime_error("helper::resizeFunctor: Non implemented yet."); + } + return ND4J_STATUS_OK; + } + +// ------------------------------------------------------------------------------------------------------------------ // +// ------------------------------------------------------------------------------------------------------------------ // +// crop and resize helper functor: +// \@param context - launch context for operation +// \@param images - batch of images (4D tensor) with shape {batch, width, height, channels} with given type +// \@param boxes - float boxes for crop +// \@param indices - integer boxes indices for crop +// \@param cropSize - integer size (newWidth, newHeight) +// \@param method - one of bilinear (0) or nearest neighbour (1) interpolation algorithm +// \@param extrapolationVal - radix to increase/decrease image +// \@param crops - output image batch (4D with given type) +// void - 
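The compute/interpolate1D pair used throughout this section is a plain four-tap weighted sum, applied once down the y-axis and once across x. A self-contained sketch; the hard-coded weights are the Keys a = -0.5 kernel at t = 0.5 (an assumption for illustration), and the result shows why bicubic output can overshoot the input range:

    #include <cstdio>

    // One-dimensional four-tap blend, as interpolate1D above.
    template <typename T>
    inline float interpolate1D(float w0, float w1, float w2, float w3,
                               T v0, T v1, T v2, T v3) {
        return float(v0) * w0 + float(v1) * w1 + float(v2) * w2 + float(v3) * w3;
    }

    int main() {
        // Catmull-Rom weights at t = 0.5 are {-1/16, 9/16, 9/16, -1/16}.
        float mid = interpolate1D(-0.0625f, 0.5625f, 0.5625f, -0.0625f, 0, 1, 1, 0);
        printf("%f\n", mid); // 1.125000: the cubic overshoots a flat-topped edge
        return 0;
    }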
cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, + cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops) { BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), indices->dataType(), cropAndResizeFunctor_, (images, boxes, indices, cropSize, method, extrapolationVal, crops), NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu b/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu index 7134d764a..1c0bc1925 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu @@ -44,7 +44,7 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo, const Y* y = reinterpret_cast(vy); X* z = reinterpret_cast(vz); - __shared__ int rank, channelPosition; + __shared__ int rank, channelPosition, posOfNonUnityDim; __shared__ Nd4jLong *sharedMem, len; __shared__ bool xzSameOffsets, xzAreSame; @@ -58,6 +58,8 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo, len = shape::length(xShapeInfo); channelPosition = isNCHW ? 1 : rank - 1; // second or last xzAreSame = x == z; + + shape::isCommonVector(yShapeInfo, posOfNonUnityDim); } __syncthreads(); @@ -69,7 +71,7 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo, const auto xOffsets = shape::getOffset(xShapeInfo, coords); const auto zOffsets = xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords); - const auto yOffsets = shape::getOffset(yShapeInfo, coords + channelPosition); + const auto yOffsets = coords[channelPosition] * shape::stride(yShapeInfo)[posOfNonUnityDim]; if(xzAreSame) z[zOffsets] += static_cast(y[yOffsets]); @@ -94,7 +96,7 @@ void addBias(nd4j::graph::Context& block, const NDArray& input, const NDArray& b PointersManager manager(block.launchContext(), "addBias"); - const int threadsPerBlock = MAX_NUM_THREADS; + const int threadsPerBlock = MAX_NUM_THREADS/2; const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu b/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu index dc1935b83..f924fdb75 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu @@ -34,64 +34,59 @@ static __global__ void col2imCuda(const void* columns, const Nd4jLong* colShapeI const T* col = reinterpret_cast(columns); T* im = reinterpret_cast(image); - __shared__ int colRank, imRank, kHeff, kWeff, oH, oW; - __shared__ Nd4jLong *sharedMem, imLen; + __shared__ uint kH, kW, oH, oW, *sharedMem; + __shared__ Nd4jLong imLen; if (threadIdx.x == 0) { extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + sharedMem = reinterpret_cast(shmem); + + kH = dH * (colShapeInfo[3] - 1) + 1; + kW = dW * (colShapeInfo[4] - 1) + 1; oH = colShapeInfo[5]; oW = colShapeInfo[6]; - kHeff = colShapeInfo[3] + (colShapeInfo[3] - 1) * (dH - 1); - kWeff = colShapeInfo[4] + (colShapeInfo[4] - 1) * (dW - 1); - - imRank = 4; - colRank = 6; - - imLen = shape::length(imShapeInfo); + imLen = shape::length(imShapeInfo); } __syncthreads(); - const auto imInd = threadIdx.x + 
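The addBias.cu change just above replaces a full shape::getOffset call per element with a single stride multiply, exploiting that the bias is effectively a vector whatever its formal rank: isCommonVector finds the one non-unity dimension once, and each thread then pays one multiplication. A tiny sketch of that indexing (biasOffset is a hypothetical name):

    #include <cstdint>
    #include <cstdio>

    using Nd4jLong = std::int64_t;

    // Offset of channel c in a bias stored as {C}, {1,C}, {C,1}, etc.:
    // c times the stride of the single non-unity dimension.
    inline Nd4jLong biasOffset(Nd4jLong c, const Nd4jLong* strides, int posOfNonUnityDim) {
        return c * strides[posOfNonUnityDim];
    }

    int main() {
        Nd4jLong strides1xC[2] = {8, 1}; // shape {1, 8}: non-unity dim is 1
        printf("%lld\n", (long long)biasOffset(5, strides1xC, 1)); // 5
        return 0;
    }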
blockIdx.x * blockDim.x; + auto coords = sharedMem + threadIdx.x * 6; - if(imInd >= imLen) - return; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto coords = sharedMem + threadIdx.x * colRank; + for (Nd4jLong i = tid; i < imLen; i += gridDim.x * blockDim.x) { - shape::index2coords(imInd, imShapeInfo, coords); + shape::index2coords(i, imShapeInfo, coords); - const auto imOffset = shape::getOffset(imShapeInfo, coords); + const auto imOffset = shape::getOffset(imShapeInfo, coords); - const int imH = coords[2] + pH; - const int imW = coords[3] + pW; + const auto bSiCoffset = coords[0] * colShapeInfo[7] + coords[1] * colShapeInfo[8]; - const int colHstart = (imH < kHeff) ? 0 : (imH - kHeff) / sH + 1; - const int colWstart = (imW < kWeff) ? 0 : (imW - kWeff) / sW + 1; + const uint imH = coords[2] + pH; + const uint imW = coords[3] + pW; - const int colHend = nd4j::math::nd4j_min(imH / sH + 1, oH); - const int colWend = nd4j::math::nd4j_min(imW / sW + 1, oW); + const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; + const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1; - T val = 0; + const uint colHend = nd4j::math::nd4j_min(imH / sH + 1, oH); + const uint colWend = nd4j::math::nd4j_min(imW / sW + 1, oW); - for(coords[4] = colHstart; coords[4] < colHend; ++coords[4]) { - coords[2] = imH - coords[4] * sH; + T val = 0; - for(coords[5] = colWstart; coords[5] < colWend; ++coords[5]) { - coords[3] = imW - coords[5] * sW; + for(coords[4] = colHstart; coords[4] < colHend; ++coords[4]) { + coords[2] = imH - coords[4] * sH; + if(coords[2] % dH != 0) continue; - if(coords[2] % dH == 0 && coords[3] % dW == 0) { - coords[2] /= dH; - coords[3] /= dW; + for(coords[5] = colWstart; coords[5] < colWend; ++coords[5]) { + coords[3] = imW - coords[5] * sW; + if(coords[3] % dW != 0) continue; - val += col[shape::getOffset(colShapeInfo, coords)]; + val += col[bSiCoffset + (coords[2]/dH)*colShapeInfo[9] + (coords[3]/dW)*colShapeInfo[10] + coords[4]*colShapeInfo[11] + coords[5]*colShapeInfo[12]]; } } + im[imOffset] = val; } - - im[imOffset] = val; } //////////////////////////////////////////////////////////////////////// @@ -184,8 +179,8 @@ static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBloc void* image, const Nd4jLong* imShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - col2imCuda2<<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW); - //col2imCuda<<>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW); + // col2imCuda2<<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW); + col2imCuda<<>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW); } ////////////////////////////////////////////////////////////////////////// @@ -195,7 +190,7 @@ void col2im(nd4j::LaunchContext& context, const NDArray& col, NDArray& im, const const int threadsPerBlock = MAX_NUM_THREADS / 2; const int blocksPerGrid = (im.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; NDArray::prepareSpecialUse({&im}, {&col}); BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES); 
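The col2im rewrite above switches from one-thread-per-element with an early return to a grid-stride loop, so a fixed launch geometry covers any image length. A minimal CUDA sketch of that pattern on its own (unrelated payload, just the loop shape):

    #include <cstdio>

    // Grid-stride loop: each thread starts at its global id and strides by
    // the total thread count, as the rewritten col2imCuda does over imLen.
    __global__ void scaleKernel(const float* in, float* out, long long len, float f) {
        const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
        for (long long i = tid; i < len; i += (long long)gridDim.x * blockDim.x)
            out[i] = in[i] * f;
    }

    int main() {
        const long long len = 1 << 20;
        float *in, *out;
        cudaMallocManaged(&in, len * sizeof(float));
        cudaMallocManaged(&out, len * sizeof(float));
        for (long long i = 0; i < len; ++i) in[i] = 1.0f;
        const int threads = 256;
        const int blocks = (int)((len + threads - 1) / threads);
        scaleKernel<<<blocks, threads>>>(in, out, len, 2.0f);
        cudaDeviceSynchronize();
        printf("%f\n", out[len - 1]); // 2.000000
        cudaFree(in); cudaFree(out);
        return 0;
    }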
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index 273749bfd..4887b7266 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -121,74 +122,71 @@ static __global__ void col2volCuda(const void* columns, const Nd4jLong* colShape const T* col = reinterpret_cast(columns); T* vol = reinterpret_cast(volume); - __shared__ int colRank, volRank, kDeff, kHeff, kWeff, oD, oH, oW; - __shared__ Nd4jLong *sharedMem, volLen; + __shared__ uint kD, kH, kW, oD, oH, oW, *sharedMem; + __shared__ Nd4jLong volLen; if (threadIdx.x == 0) { extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + sharedMem = reinterpret_cast(shmem); oD = colShapeInfo[6]; oH = colShapeInfo[7]; oW = colShapeInfo[8]; - kDeff = colShapeInfo[3] + (colShapeInfo[3] - 1) * (dD - 1); - kHeff = colShapeInfo[4] + (colShapeInfo[4] - 1) * (dH - 1); - kWeff = colShapeInfo[5] + (colShapeInfo[5] - 1) * (dW - 1); + kD = dD * (colShapeInfo[3] - 1) + 1; + kH = dH * (colShapeInfo[4] - 1) + 1; + kW = dW * (colShapeInfo[5] - 1) + 1; - volRank = 5; - colRank = 8; - - volLen = shape::length(volShapeInfo); + volLen = shape::length(volShapeInfo); } __syncthreads(); - const auto volInd = threadIdx.x + blockIdx.x * blockDim.x; + auto coords = sharedMem + threadIdx.x * 8; - if(volInd >= volLen) - return; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto coords = sharedMem + threadIdx.x * colRank; + for (Nd4jLong i = tid; i < volLen; i += gridDim.x * blockDim.x) { - shape::index2coords(volInd, volShapeInfo, coords); + shape::index2coords(i, volShapeInfo, coords); - const auto volOffset = shape::getOffset(volShapeInfo, coords); + const auto volOffset = shape::getOffset(volShapeInfo, coords); - const int imD = coords[2] + pD; - const int imH = coords[3] + pH; - const int imW = coords[4] + pW; + const auto bSiCoffset = coords[0] * colShapeInfo[9] + coords[1] * colShapeInfo[10]; - const int colDstart = (imD < kDeff) ? 0 : (imD - kDeff) / sD + 1; - const int colHstart = (imH < kHeff) ? 0 : (imH - kHeff) / sH + 1; - const int colWstart = (imW < kWeff) ? 0 : (imW - kWeff) / sW + 1; + const uint imD = coords[2] + pD; + const uint imH = coords[3] + pH; + const uint imW = coords[4] + pW; - const int colDend = nd4j::math::nd4j_min(imD / sD + 1, oD); - const int colHend = nd4j::math::nd4j_min(imH / sH + 1, oH); - const int colWend = nd4j::math::nd4j_min(imW / sW + 1, oW); + const uint colDstart = (imD < kD) ? 0 : (imD - kD) / sD + 1; + const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; + const uint colWstart = (imW < kW) ? 
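Both the col2im and col2vol rewrites compute the dilated kernel extent as d*(k-1)+1, which is algebraically identical to the previous kEff = k + (k-1)*(d-1), just folded. A one-function sketch:

    #include <cstdio>

    // Effective extent of a k-tap kernel with dilation d: d*(k-1)+1.
    // Expanding k + (k-1)*(d-1) gives the same expression.
    inline int effectiveKernel(int k, int d) { return d * (k - 1) + 1; }

    int main() {
        printf("%d\n", effectiveKernel(3, 1)); // 3: no dilation
        printf("%d\n", effectiveKernel(3, 2)); // 5: one gap between taps
        return 0;
    }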
0 : (imW - kW) / sW + 1; - T val = 0; + const uint colDend = nd4j::math::nd4j_min(imD / sD + 1, oD); + const uint colHend = nd4j::math::nd4j_min(imH / sH + 1, oH); + const uint colWend = nd4j::math::nd4j_min(imW / sW + 1, oW); - for(coords[5] = colDstart; coords[5] < colDend; ++coords[5]) { - coords[2] = imD - coords[5] * sD; + T val = 0; - for(coords[6] = colHstart; coords[6] < colHend; ++coords[6]) { - coords[3] = imH - coords[6] * sH; + for(uint colD = colDstart; colD < colDend; ++colD) { + coords[2] = imD - colD * sD; + if(coords[2] % dD != 0) continue; - for(coords[7] = colWstart; coords[7] < colWend; ++coords[7]) { - coords[4] = imW - coords[7] * sW; + for(uint colH = colHstart; colH < colHend; ++colH) { + coords[3] = imH - colH * sH; + if(coords[3] % dH != 0) continue; - if(coords[2] % dD == 0 && coords[3] % dH == 0 && coords[4] % dW == 0) { - coords[2] /= dD; - coords[3] /= dH; - coords[4] /= dW; + for(uint colW = colWstart; colW < colWend; ++colW) { + coords[4] = imW - colW * sW; + if(coords[4] % dW != 0) continue; + + val += col[bSiCoffset + (coords[2]/dD)*colShapeInfo[11] + (coords[3]/dH)*colShapeInfo[12] + (coords[4]/dW)*colShapeInfo[13] + colD*colShapeInfo[14] + colH*colShapeInfo[15] + colW*colShapeInfo[16]]; - val += col[shape::getOffset(colShapeInfo, coords)]; } } } - } - vol[volOffset] = val; + vol[volOffset] = val; + } } ////////////////////////////////////////////////////////////////////////// @@ -208,7 +206,7 @@ void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& col, const int threadsPerBlock = MAX_NUM_THREADS / 4; const int blocksPerGrid = (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; NDArray::prepareSpecialUse({&vol}, {&col}); BUILD_SINGLE_SELECTOR(vol.dataType(), col2volCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu index eb632e70c..9b05f891a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu @@ -24,39 +24,112 @@ namespace nd4j { namespace ops { namespace helpers { - template - static __global__ void drawBoundingBoxesKernel(T const* images, Nd4jLong* imagesShape, T const* boxes, - Nd4jLong* boxesShape, T const* colors, Nd4jLong* colorsShape, T* output, Nd4jLong* outputShape, - Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, Nd4jLong channels, Nd4jLong colorSetSize) { + typedef NDArray ColorTable_t; + static NDArray DefaultColorTable(int depth) { + //std::vector> colorTable; + const Nd4jLong kDefaultTableLength = 10; + const Nd4jLong kDefaultChannelLength = 4; + NDArray colorTable('c', {kDefaultTableLength, kDefaultChannelLength}, { + 1,1,0,1, // yellow + 0, 0, 1, 1, // 1: blue + 1, 0, 0, 1, // 2: red + 0, 1, 0, 1, // 3: lime + 0.5, 0, 0.5, 1, // 4: purple + 0.5, 0.5, 0, 1, // 5: olive + 0.5, 0, 0, 1, // 6: maroon + 0, 0, 0.5, 1, // 7: navy blue + 0, 1, 1, 1, // 8: aqua + 1, 0, 1, 1 // 9: fuchsia + }, DataType::FLOAT32); - for (auto b = blockIdx.x; b < (int)batchSize; b += gridDim.x) { // loop by batch - for 
(auto c = 0; c < colorSetSize; c++) { + if (depth == 1) { + colorTable.assign(1.f); // all to white when black and white colors + } + return colorTable; + } + + template + static __global__ void drawBoundingBoxesKernel(T const* images, Nd4jLong* imagesShape, float const* boxes, + Nd4jLong* boxesShape, float const* colorTable, Nd4jLong* colorTableShape, T* output, Nd4jLong* outputShape, + Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, Nd4jLong channels, Nd4jLong boxSize, Nd4jLong colorTableLen) { + + for (auto batch = blockIdx.x; batch < (int)batchSize; batch += gridDim.x) { // loop by batch + for (auto boxIndex = 0; boxIndex < boxSize; ++boxIndex) { // box with shape - auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, {0})(c, {0});//internalBoxes->at(c); - auto color = &colors[channels * c];//colorSet->at(c); - auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox[0])); - auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox[2])); - auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox[1])); - auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox[3])); - for (auto y = rowStart + threadIdx.x; y <= rowEnd; y += blockDim.x) { - for (auto e = 0; e < channels; ++e) { - Nd4jLong yMinPos[] = {b, y, colStart, e}; - Nd4jLong yMaxPos[] = {b, y, colEnd, e}; - auto zIndexYmin = shape::getOffset(outputShape, yMinPos); - auto zIndexYmax = shape::getOffset(outputShape, yMaxPos); - output[zIndexYmin] = color[e]; - output[zIndexYmax] = color[e]; - } + //auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, {0})(c, {0});//internalBoxes->at(c); + auto colorIndex = boxIndex % colorTableLen;//colorSet->at(c); +// auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox[0])); +// auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox[2])); +// auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox[1])); +// auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox[3])); + Nd4jLong indices0[] = {batch, boxIndex, 0}; + Nd4jLong indices1[] = {batch, boxIndex, 1}; + Nd4jLong indices2[] = {batch, boxIndex, 2}; + Nd4jLong indices3[] = {batch, boxIndex, 3}; + auto rowStart = Nd4jLong ((height - 1) * boxes[shape::getOffset(boxesShape, indices0, 0)]); + auto rowStartBound = nd4j::math::nd4j_max(Nd4jLong (0), rowStart); + auto rowEnd = Nd4jLong ((height - 1) * boxes[shape::getOffset(boxesShape, indices2, 0)]); + auto rowEndBound = nd4j::math::nd4j_min(Nd4jLong (height - 1), rowEnd); + auto colStart = Nd4jLong ((width - 1) * boxes[shape::getOffset(boxesShape, indices1, 0)]); + auto colStartBound = nd4j::math::nd4j_max(Nd4jLong (0), colStart); + auto colEnd = Nd4jLong ((width - 1) * boxes[shape::getOffset(boxesShape, indices3, 0)]); + auto colEndBound = nd4j::math::nd4j_min(Nd4jLong(width - 1), colEnd); + if (rowStart > rowEnd || colStart > colEnd) { +// printf("helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is inverted " +// "and will not be drawn\n", rowStart, colStart, rowEnd, colEnd); + continue; } - for (auto x = colStart + 1 + threadIdx.x; x < colEnd; x += blockDim.x) { - for (auto e = 0; e < channels; ++e) { - Nd4jLong xMinPos[] = {b, rowStart, x, e}; - Nd4jLong xMaxPos[] = {b, rowEnd, x, e}; - auto zIndexXmin = shape::getOffset(outputShape, xMinPos); - auto zIndexXmax = 
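The rewritten bounding-box kernel above maps normalised box corners to pixel rows and columns, clamps them, and skips inverted or fully-outside boxes before drawing any edge. A host-side sketch of that validation and clamping (boxToPixels is a hypothetical name):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    using Nd4jLong = std::int64_t;

    // Maps (y1,x1,y2,x2) in [0,1] to clamped pixel bounds; returns false for
    // inverted or fully-outside boxes, mirroring the kernel's two checks.
    bool boxToPixels(float y1, float x1, float y2, float x2,
                     Nd4jLong height, Nd4jLong width,
                     Nd4jLong bounds[4] /* rowStart,rowEnd,colStart,colEnd */) {
        Nd4jLong rowStart = (Nd4jLong)((height - 1) * y1);
        Nd4jLong rowEnd   = (Nd4jLong)((height - 1) * y2);
        Nd4jLong colStart = (Nd4jLong)((width - 1) * x1);
        Nd4jLong colEnd   = (Nd4jLong)((width - 1) * x2);
        if (rowStart > rowEnd || colStart > colEnd) return false; // inverted
        if (rowStart >= height || rowEnd < 0 || colStart >= width || colEnd < 0)
            return false; // completely outside the image
        bounds[0] = std::max<Nd4jLong>(0, rowStart);
        bounds[1] = std::min<Nd4jLong>(height - 1, rowEnd);
        bounds[2] = std::max<Nd4jLong>(0, colStart);
        bounds[3] = std::min<Nd4jLong>(width - 1, colEnd);
        return true;
    }

    int main() {
        Nd4jLong b[4];
        if (boxToPixels(0.1f, 0.2f, 0.9f, 0.8f, 100, 200, b))
            printf("rows %lld..%lld, cols %lld..%lld\n",
                   (long long)b[0], (long long)b[1], (long long)b[2], (long long)b[3]);
        return 0;
    }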
shape::getOffset(outputShape, xMaxPos); - output[zIndexXmin] = color[e]; - output[zIndexXmax] = color[e]; - } + if (rowStart >= height || rowEnd < 0 || colStart >= width || + colEnd < 0) { +// printf("helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is completely " +// "outside the image and not be drawn\n", rowStart, colStart, rowEnd, colEnd); + continue; + } + + // Draw upper line + if (rowStart >= 0) { + for (auto j = colStartBound + threadIdx.x; j <= colEndBound; j += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, rowStart, j, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } + // Draw bottom line. + if (rowEnd < height) { + for (auto j = colStartBound + threadIdx.x; j <= colEndBound; j += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, rowEnd, j, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } + + // Draw left line. + if (colStart >= 0) { + for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; i += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, i, colStart, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } + // Draw right line. + if (colEnd < width) { + for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; i += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, i, colEnd, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } } } } @@ -70,15 +143,19 @@ namespace helpers { auto width = images->sizeAt(2); auto channels = images->sizeAt(3); auto stream = context->getCudaStream(); - auto colorSetSize = colors->sizeAt(0); + auto boxSize = boxes->sizeAt(1); + NDArray colorsTable = DefaultColorTable(channels); + if ((colors != nullptr && colors->lengthOf() > 0)) { + colorsTable = *colors; + } auto imagesBuf = images->getDataBuffer()->specialAsT(); - auto boxesBuf = boxes->getDataBuffer()->specialAsT(); - auto colorsBuf = colors->getDataBuffer()->specialAsT(); + auto boxesBuf = boxes->getDataBuffer()->specialAsT(); // boxes should be float32 + auto colorsTableBuf = colorsTable.getDataBuffer()->specialAsT(); // color table is float32 auto outputBuf = output->dataBuffer()->specialAsT(); - drawBoundingBoxesKernel<< 128? 
128: batchSize, 256, 1024, *stream>>>(imagesBuf, images->getSpecialShapeInfo(), - boxesBuf, boxes->getSpecialShapeInfo(), colorsBuf, colors->getSpecialShapeInfo(), - outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, colorSetSize); + drawBoundingBoxesKernel<<<128, 128, 1024, *stream>>>(imagesBuf, images->getSpecialShapeInfo(), + boxesBuf, boxes->getSpecialShapeInfo(), colorsTableBuf, colorsTable.getSpecialShapeInfo(), + outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, boxSize, colorsTable.lengthOf()); } void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index 5e57a9a90..0042877a9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -294,15 +294,638 @@ namespace helpers { int width, int height, bool center, NDArray* output), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Bicubic interpolation +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Utility functions and classes + + // calculateResizeScale determines the float scaling factor. + inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, + bool alignCorners) { + return (alignCorners && outSize > 1) + ? (inSize - 1) / static_cast(outSize - 1) + : inSize / static_cast(outSize); + } + + struct ImageResizerState { + explicit ImageResizerState(bool alignCorners, bool halfPixelCenters) + : _alignCorners(alignCorners), + _halfPixelCenters(halfPixelCenters) {} + + // ValidateAndCalculateOutputSize checks the bounds on the input tensors + // and requested size, sets up some of the resizing state such as the + // heightScale and widthScale, and calculates the output size. + // If any of these operations fails, it sets an error status in + // the context, which the caller must check. + int validateAndCalculateOutputSize(NDArray const* input, int const width, int const height) { + // + batchSize = input->sizeAt(0);//.dim_size(0); + outHeight = height; + outWidth = width; //internal::SubtleMustCopy(Svec(1)); + inHeight = static_cast(input->sizeAt(1)); + inWidth = static_cast(input->sizeAt(2)); + channels = input->sizeAt(3); //.dim_size(3); + heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners); + widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners); + + // Guard against overflows + if (ceilf((outHeight - 1) * heightScale) > static_cast(DataTypeUtils::max())) { + nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); + return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize height"); + } + if (ceilf((outWidth - 1) * heightScale) > static_cast(DataTypeUtils::max())) { + nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); + return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize width"); + } + + return Status::OK(); + } + + // Calculates all the required variables, and allocates the output. 
+ int validateAndCreateOutput(NDArray const* input, int const width, int const height) { + return validateAndCalculateOutputSize(input, width, height); + } + + Nd4jLong batchSize; + Nd4jLong outHeight; + Nd4jLong outWidth; + Nd4jLong inHeight; + Nd4jLong inWidth; + Nd4jLong channels; + float heightScale; + float widthScale; + NDArray* output = nullptr; + cudaStream_t* stream; + private: + bool _alignCorners; + bool _halfPixelCenters; + }; + + // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the +// floating point coordinates of the top,left pixel is 0.5,0.5. + struct HalfPixelScaler { + _CUDA_HD HalfPixelScaler(){}; + inline _CUDA_HD float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale - 0.5f; + } + }; + + struct WeightsAndIndices { + float _weight0; + float _weight1; + float _weight2; + float _weight3; + Nd4jLong _index0; + Nd4jLong _index1; + Nd4jLong _index2; + Nd4jLong _index3; + + int _advance; // advance value. + }; + + class CachedInterpolationCalculator { + public: + _CUDA_HD CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {} + + // Advances iteration. Returns the number of values that should be copied from + // the current point to the next point. The copying should always be done by + // copying the last values from the old point to the first + // values of the new point. + inline _CUDA_HD int Advance(const Nd4jLong x0, const Nd4jLong x1, const Nd4jLong x2, + const Nd4jLong x3) { + // We use 2 hands and walk through, copying from one to another where + // we already have values. + // Invariant, new_indicies_hand <= cached_values_hand + const Nd4jLong new_x_indices[4] = {x0, x1, x2, x3}; + int cachedValuesHand = 0; + int newIndiciesHand = 0; + while (cachedValuesHand < 4) { + if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) { + if (newIndiciesHand < cachedValuesHand) { + _indexes[newIndiciesHand] = _indexes[cachedValuesHand]; + } + newIndiciesHand++; + } + cachedValuesHand++; + } + switch (newIndiciesHand) { + case 0: + _indexes[0] = x0; + case 1: + _indexes[1] = x1; + case 2: + _indexes[2] = x2; + case 3: + _indexes[3] = x3; + break; + } + return newIndiciesHand; + } + + private: + Nd4jLong _indexes[4]; + }; + + + static __global__ void initCoefTableKernel(const double a, float* table, Nd4jLong tableSize) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (int i = start; i <= tableSize; i += step) { + float x = i * 1.0 / tableSize; + table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1; + x += 1.0; + table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; + } + } + + static const Nd4jLong kTableSize = (1 << 10); + float* initCoeffsTable(const double a, cudaStream_t* stream) { + // Allocate and initialize coefficients table using Bicubic + // convolution algorithm. 
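Unlike the CPU path, the CUDA variant builds the coefficient table on device via initCoefTableKernel. A compilable sketch of that table construction, following the same layout (slot i*2 holds the near-tap kernel value at x = i/tableSize, slot i*2+1 the far-tap value at x + 1):

    #include <cstdio>

    __global__ void buildTable(double a, float* table, long long tableSize) {
        auto start = blockIdx.x * blockDim.x + threadIdx.x;
        auto step  = (long long)blockDim.x * gridDim.x;
        for (long long i = start; i <= tableSize; i += step) {
            float x = i * 1.0f / tableSize;
            table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1;
            x += 1.0f;
            table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
        }
    }

    int main() {
        const long long kTableSize = 1 << 10;
        float* table;
        cudaMallocManaged(&table, sizeof(float) * (kTableSize + 1) * 2);
        buildTable<<<128, 128>>>(-0.75, table, kTableSize);
        cudaDeviceSynchronize();
        printf("near(0)=%f far(0)=%f\n", table[0], table[1]); // 1.000000 and 0.000000
        cudaFree(table);
        return 0;
    }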
+ // https://en.wikipedia.org/wiki/Bicubic_interpolation + float* coeffs_table; // = new float[(kTableSize + 1) * 2]; + auto err = cudaMalloc(&coeffs_table, sizeof(float) * ((kTableSize + 1) * 2)); + if (err != 0) { + throw cuda_exception::build("helpers::initCoeffsTable: Cannot allocate memory for vertical parts rectangulars", err); + } + + + initCoefTableKernel<<<128,128,128, *stream>>>(a, coeffs_table, kTableSize); + err = cudaStreamSynchronize(*stream); + if (err != 0) { + throw cuda_exception::build("helpers::initCoeffsTable: Cannot syncronize kernel", err); + } + + return coeffs_table; + } +// _CUDA_HD const float* getCoeffsTable(const bool use_keys_cubic) { +// // Static so that we initialize it on first use +// if (use_keys_cubic) { +// // http://ieeexplore.ieee.org/document/1163711/ +// // R. G. Keys. Cubic convolution interpolation for digital image +// // processing. IEEE Transactions on Acoustics, Speech, and Signal +// // Processing, 29(6):1153–1160, 1981. +// //static const float* coeffs_table = initCoeffsTable(-0.5f, stream); +// return sCoeffsTableHalf; +// } else { +// //static const float* coeffs_table = initCoeffsTable(-0.75f, stream); +// return sCoeffsTableThreeFourth; +// } +// } + + inline _CUDA_HD Nd4jLong bound(Nd4jLong val, Nd4jLong limit) { + return math::nd4j_min(limit - 1ll, math::nd4j_max(Nd4jLong{0}, val)); + } + + + template + inline _CUDA_HD float interpolate1D(const float weight0, const float weight1, const float weight2, const float weight3, + const T value0, const T value1, const T value2, const T value3) { + return static_cast(value0) * weight0 + + static_cast(value1) * weight1 + + static_cast(value2) * weight2 + + static_cast(value3) * weight3; + } + +// Compute the 1D interpolation for a given X index using the y_weights + static _CUDA_HD float compute(float values[4], const float xW0, const float xW1, const float xW2, const float xW3) { + return interpolate1D(xW0, xW1, xW2, xW3, values[0], values[1],values[2], values[3]); + } + + + + template + inline _CUDA_HD void getWeightsAndIndices(float const* coeffs_table, const float scale, const Nd4jLong out_loc, const Nd4jLong limit, WeightsAndIndices* out) { + const Scaler scaler; + const float in_loc_f = scaler(out_loc, scale); + const Nd4jLong in_loc = math::nd4j_floor(in_loc_f); + const float delta = in_loc_f - in_loc; + const Nd4jLong offset = math::nd4j_round(delta * kTableSize); + //const float* coeffs_table = getCoeffsTable(use_keys_cubic); + if (use_keys_cubic) { + // The legacy code placed more weight on the edge pixels, since bounding + // the set of inputs to sample could cause an edge pixel to be repeated. + // Here we change the behavior at borders to match that used by the + // scale_and_translate_op, where sampling locations outside the image have + // their weight set to 0, and the weights are renormalized so that their sum + // is 1.0. + out->_index0 = bound(in_loc - 1, limit); + out->_weight0 = + (out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f); + out->_index1 = bound(in_loc, limit); + out->_weight1 = (out->_index1 == in_loc ? coeffs_table[offset * 2] : 0.0f); + out->_index2 = bound(in_loc + 1, limit); + out->_weight2 = + (out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] + : 0.0f); + out->_index3 = bound(in_loc + 2, limit); + out->_weight3 = (out->_index3 == in_loc + 2 + ? 
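In the Keys-cubic branch above, taps whose clamped index no longer matches the intended location get weight 0, and the survivors are renormalised so the four weights still sum to 1, with a guard against a near-zero sum. A small sketch of that renormalisation; the FLT_MIN-based threshold stands in for DataTypeUtils::min():

    #include <cmath>
    #include <cstdio>
    #include <limits>

    // Renormalise four border-adjusted weights so they sum to 1, skipping
    // the division when the sum is vanishingly small.
    void renormalize(float w[4]) {
        float sum = w[0] + w[1] + w[2] + w[3];
        if (std::fabs(sum) >= 1000.0f * std::numeric_limits<float>::min()) {
            float inv = 1.0f / sum;
            for (int i = 0; i < 4; ++i) w[i] *= inv;
        }
    }

    int main() {
        float w[4] = {0.0f, 0.5625f, 0.5625f, -0.0625f}; // leftmost tap dropped at a border
        renormalize(w);
        printf("%f %f %f %f\n", w[0], w[1], w[2], w[3]); // rescaled to sum to 1
        return 0;
    }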
coeffs_table[(kTableSize - offset) * 2 + 1] + : 0.0f); + + const float weight_sum = + out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3; + if (math::nd4j_abs(weight_sum) >= 1000.0f * DataTypeUtils::min()) { + const float one_over_weight_sum = 1.0f / weight_sum; + out->_weight0 *= one_over_weight_sum; + out->_weight1 *= one_over_weight_sum; + out->_weight2 *= one_over_weight_sum; + out->_weight3 *= one_over_weight_sum; + } + } else { + out->_weight0 = coeffs_table[offset * 2 + 1]; + out->_weight1 = coeffs_table[offset * 2]; + out->_weight2 = coeffs_table[(kTableSize - offset) * 2]; + out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1]; + out->_index0 = bound(in_loc - 1, limit); + out->_index1 = bound(in_loc, limit); + out->_index2 = bound(in_loc + 1, limit); + out->_index3 = bound(in_loc + 2, limit); + } + } + +// Older incorrect scaling method that causes all resizes to have a slight +// translation leading to inconsistent results. For example, a flip then a +// resize gives different results then a resize then a flip. + struct LegacyScaler { + _CUDA_HD LegacyScaler(){}; + inline _CUDA_HD float operator()(const int x, const float scale) const { + return static_cast(x) * scale; + } + }; + + static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto x = start; x < outWidth; x += step) { + pXWais[x]._index0 *= channels; + pXWais[x]._index1 *= channels; + pXWais[x]._index2 *= channels; + pXWais[x]._index3 *= channels; + } + } + + static __global__ void advaceWeightsAndIndicesKernel(float const* cacheTable, CachedInterpolationCalculator* calc, WeightsAndIndices* pXWais, Nd4jLong inWidth, float widthScale, + Nd4jLong outWidth, Nd4jLong channels, bool halfPixelCenters) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto x = start; x < outWidth; x += step) { + if (halfPixelCenters) + getWeightsAndIndices(cacheTable, widthScale, x, inWidth, &pXWais[x]); + else + getWeightsAndIndices(cacheTable, widthScale, x, inWidth, &pXWais[x]); + pXWais[x]._advance = calc->Advance(pXWais[x]._index0, pXWais[x]._index1, pXWais[x]._index2, pXWais[x]._index3); + } + } + // resizerState and xWais are device allocated + static void computeXWeightsAndIndices(float const* coeffsTable, const ImageResizerState& resizerState, + const bool halfPixelCenters, + WeightsAndIndices* pXWais) { + + auto stream = resizerState.stream; + auto outWidth = resizerState.outWidth; + CachedInterpolationCalculator calc; // = new CachedInterpolationCalculator; + CachedInterpolationCalculator* pCalcD; + auto err = cudaMalloc(&pCalcD, sizeof(CachedInterpolationCalculator)); + if (err != 0) { + cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot allocated device memory for interpolate calculator", err); + } + err = cudaMemcpy(pCalcD, &calc, sizeof(CachedInterpolationCalculator), cudaMemcpyHostToDevice); + if (err != 0) { + cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot set up device memory for interpolate calculator", err); + } + + advaceWeightsAndIndicesKernel<<<128, 128, 128, *stream>>>(coeffsTable, pCalcD, pXWais, resizerState.inWidth, resizerState.widthScale, outWidth, resizerState.channels, halfPixelCenters); + err = cudaFree(pCalcD); + if (err != 0) { + cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot deallocated device memory for interpolate calculator", 
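The host-side setup above follows an allocate/copy/launch/free sequence with an error check after every CUDA runtime call. A minimal sketch of that pattern; note this sketch throws on failure, whereas computeXWeightsAndIndices above constructs cuda_exception::build(...) without throwing the result:

    #include <cstdio>
    #include <stdexcept>
    #include <string>

    // Wrap a CUDA runtime call and surface any failure as an exception.
    static void check(cudaError_t err, const char* what) {
        if (err != cudaSuccess)
            throw std::runtime_error(std::string(what) + ": " + cudaGetErrorString(err));
    }

    struct Calc { long long indexes[4]; }; // stand-in for CachedInterpolationCalculator

    int main() {
        Calc host{{-1, -1, -1, -1}};
        Calc* device = nullptr;
        try {
            check(cudaMalloc(&device, sizeof(Calc)), "cudaMalloc");
            check(cudaMemcpy(device, &host, sizeof(Calc), cudaMemcpyHostToDevice), "cudaMemcpy");
            check(cudaFree(device), "cudaFree");
            printf("ok\n");
        } catch (const std::exception& e) {
            printf("%s\n", e.what());
        }
        return 0;
    }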
err); + } + err = cudaStreamSynchronize(*stream); + if (err != 0) { + cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot synchronize stream after advance weights and indicers", err); + } + // Scale the values so they can be used as offsets into buffers. + accumulateChannelsKernel<<<128, 128, 512, *stream>>>(pXWais, outWidth, resizerState.channels); + err = cudaStreamSynchronize(*stream); + if (err != 0) { + cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot synchronize stream after accumulate channels", err); + } + + } + + template + static _CUDA_HD FORCEINLINE float computeYInterpolation( + int which, int channelNum, const WeightsAndIndices& yWai, + const T* pY0, const T* pY1, const T* pY2, const T* pY3, + const WeightsAndIndices& xWai) { + int xIndex; + switch (which) { + case 0: + xIndex = xWai._index0; + break; + case 1: + xIndex = xWai._index1; + break; + case 2: + xIndex = xWai._index2; + break; + default: + xIndex = xWai._index3; + break; + } + const Nd4jLong pt_index = xIndex + channelNum; + return interpolate1D(yWai._weight0, yWai._weight1, yWai._weight2, + yWai._weight3, pY0[pt_index], pY1[pt_index], + pY2[pt_index], pY3[pt_index]); + } + + template + static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, float* cachedValue, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, T* outputPtr) { +// auto numChannels = pResizerState->channels; + for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) { + auto pInput = inputPtr + b * inBatchWidth; + for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; y += blockDim.x) { + auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * pResizerState->channels; + auto pOutput = &outputPtr[pos]; + struct WeightsAndIndices yWai; + if (halfPixelCenters) { + getWeightsAndIndices(cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, &yWai); + } else { + getWeightsAndIndices(cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, &yWai); + } + // Make pointers represent offsets of data in inputBPtr. + const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; + const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; + const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; + const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; + + if (pResizerState->channels == 3) { + // Manually unroll case of 3 channels. + float cached_value_0[4] = {0}; + float cached_value_1[4] = {0}; + float cached_value_2[4] = {0}; + for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) { + const WeightsAndIndices& xWai = xWais[x]; + // Shift values in cached_value_* to fill first '_advance' values. 
+                        switch (xWai._advance) {
+                            case 3:
+                                cached_value_0[0] = cached_value_0[1];
+                                cached_value_0[1] = cached_value_0[2];
+                                cached_value_0[2] = cached_value_0[3];
+                                cached_value_1[0] = cached_value_1[1];
+                                cached_value_1[1] = cached_value_1[2];
+                                cached_value_1[2] = cached_value_1[3];
+                                cached_value_2[0] = cached_value_2[1];
+                                cached_value_2[1] = cached_value_2[2];
+                                cached_value_2[2] = cached_value_2[3];
+                                break;
+                            case 2:
+                                cached_value_0[0] = cached_value_0[2];
+                                cached_value_0[1] = cached_value_0[3];
+                                cached_value_1[0] = cached_value_1[2];
+                                cached_value_1[1] = cached_value_1[3];
+                                cached_value_2[0] = cached_value_2[2];
+                                cached_value_2[1] = cached_value_2[3];
+                                break;
+                            case 1: {
+                                cached_value_0[0] = cached_value_0[3];
+                                cached_value_1[0] = cached_value_1[3];
+                                cached_value_2[0] = cached_value_2[3];
+                                break;
+                            }
+                        }
+
+                        // Set the remaining '4-_advance' values by computing.
+                        switch (xWai._advance) {
+                            case 0:
+                                cached_value_0[0] = computeYInterpolation(0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                cached_value_1[0] = computeYInterpolation(0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                cached_value_2[0] = computeYInterpolation(0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                            case 1:
+                                cached_value_0[1] = computeYInterpolation(1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                cached_value_1[1] = computeYInterpolation(1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                cached_value_2[1] = computeYInterpolation(1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                            case 2:
+                                cached_value_0[2] = computeYInterpolation(2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                cached_value_1[2] = computeYInterpolation(2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                cached_value_2[2] = computeYInterpolation(2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                            case 3:
+                                cached_value_0[3] = computeYInterpolation(3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                cached_value_1[3] = computeYInterpolation(3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                cached_value_2[3] = computeYInterpolation(3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                // break;
+                        }
+                        pOutput[x * pResizerState->channels + 0] = compute(cached_value_0, xWai._weight0, xWai._weight1,
+                                                                           xWai._weight2, xWai._weight3);
+                        pOutput[x * pResizerState->channels + 1] = compute(cached_value_1, xWai._weight0, xWai._weight1,
+                                                                           xWai._weight2, xWai._weight3);
+                        pOutput[x * pResizerState->channels + 2] = compute(cached_value_2, xWai._weight0, xWai._weight1,
+                                                                           xWai._weight2, xWai._weight3);
+                    }
+                } else {
+                    for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) {
+                        const WeightsAndIndices& xWai = xWais[x];
+                        // Shift values in cachedValue to fill first '_advance' values.
+                        switch (xWai._advance) {
+                            case 3:
+                                for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
+                                    cachedValue[4 * c + 0] = cachedValue[4 * c + 1];
+                                    cachedValue[4 * c + 1] = cachedValue[4 * c + 2];
+                                    cachedValue[4 * c + 2] = cachedValue[4 * c + 3];
+                                }
+                                break;
+                            case 2:
+                                for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
+                                    cachedValue[4 * c + 0] = cachedValue[4 * c + 2];
+                                    cachedValue[4 * c + 1] = cachedValue[4 * c + 3];
+                                }
+                                break;
+                            case 1: {
+                                for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
+                                    cachedValue[4 * c + 0] = cachedValue[4 * c + 3];
+                                }
+                                break;
+                            }
+                        }
+
+                        // Set the remaining '4-_advance' values by computing.
+                        switch (xWai._advance) {
+                            case 0:
+                                for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
+                                    cachedValue[4 * c + 0] = computeYInterpolation(0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                }
+                            case 1:
+                                for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
+                                    cachedValue[4 * c + 1] = computeYInterpolation(1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                }
+                            case 2:
+                                for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
+                                    cachedValue[4 * c + 2] = computeYInterpolation(2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                }
+                            case 3:
+                                for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
+                                    cachedValue[4 * c + 3] = computeYInterpolation(3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
+                                }
+                                // break;
+                        }
+                        for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
+                            pOutput[x * pResizerState->channels + c] = compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, xWai._weight2, xWai._weight3);
+                        }
+                    }
+                }
+            }
+        }
+    }
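The kernel above never recomputes vertical interpolations that neighbouring output columns share: `_advance` is the number of the four cached values still valid for the current window, the first switch shifts them to the front of the cache, and the second switch (deliberately fall-through, hence the commented-out break) recomputes only the missing tail before `compute` blends the four values horizontally. A standalone 1-D sketch of that scheme, with a stand-in for computeYInterpolation and illustrative values only (not the patch's code):

    #include <cstdio>

    // plays the role of 'compute' above: weighted sum of four cached values
    static float blend(const float v[4], const float w[4]) {
        return v[0] * w[0] + v[1] * w[1] + v[2] * w[2] + v[3] * w[3];
    }

    int main() {
        float cache[4] = {0};
        int prevStart = -100;                       // _index0 of the previous window
        for (int x = 0; x < 5; ++x) {               // toy output row
            int start = x;                          // toy source window start
            int keep = prevStart + 4 - start;       // how many cached values overlap
            if (keep < 0 || keep > 3) keep = 0;     // first column: nothing reusable
            for (int i = 0; i < keep; ++i)          // shift survivors to the front
                cache[i] = cache[i + (4 - keep)];
            for (int i = keep; i < 4; ++i)          // recompute only the tail
                cache[i] = static_cast<float>(start + i); // stand-in for computeYInterpolation
            const float w[4] = {-0.05f, 0.55f, 0.55f, -0.05f}; // illustrative weights, sum to 1
            std::printf("x=%d -> %.2f\n", x, blend(cache, w));
            prevStart = start;
        }
        return 0;
    }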
+
+    template <typename T>
+    static void bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) {
+        const auto numChannels = resizerState.channels;
+        const Nd4jLong inRowWidth = resizerState.inWidth * numChannels;
+        const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth;
+
+        auto stream = resizerState.stream; //output->getContext()->getCudaStream();
+        ImageResizerState* resizerStateD;
+        auto err = cudaMalloc(&resizerStateD, sizeof(ImageResizerState));
+        if (err != 0) {
+            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for resizerState", err);
+        }
+        err = cudaMemcpy(resizerStateD, &resizerState, sizeof(ImageResizerState), cudaMemcpyHostToDevice);
+        if (err != 0) {
+            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for resizerState", err);
+        }
+
+        float* cachedValue = nullptr;
+        // the 3-channel case caches in per-thread arrays inside the kernel, so no global buffer is needed there
+        size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels);
+        if (cachedSize) {
+            err = cudaMalloc(reinterpret_cast<void**>(&cachedValue), cachedSize);
+            if (err != 0) {
+                throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err);
+            }
+            err = cudaMemset(cachedValue, 0, cachedSize);
+            if (err != 0) {
+                throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err);
+            }
+        }
+
+        WeightsAndIndices* xWais; //(resizerState.outWidth);
+        err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth);
+        if (err != 0) {
+            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for weights and indices", err);
+        }
+
+        auto coeffsTable = halfPixelCenters ? initCoeffsTable(-0.5, stream) : initCoeffsTable(-0.75, stream);
+        computeXWeightsAndIndices(coeffsTable, resizerState, halfPixelCenters, xWais);
+        err = cudaStreamQuery(*stream);
+        if (err != 0) {
+            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: computeXWeightsAndIndices finished with error", err);
+        }
+        const T* pInput = image->getDataBuffer()->specialAsT<T>();
+        T* pOutput = output->dataBuffer()->specialAsT<T>(); //_data.data();
+        bicubicInterpolateWithCachingKernel<T><<<128, 1, 512, *stream>>>(coeffsTable, cachedValue, pInput,
+                resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput);
+        err = cudaStreamSynchronize(*stream);
+        if (err != 0) {
+            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Kernels finished with error", err);
+        }
+
+        err = cudaFree(resizerStateD);
+        if (err != 0) {
+            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for resizerState", err);
+        }
+        if (cachedSize) {
+            err = cudaFree(cachedValue);
+            if (err != 0) {
+                throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err);
+            }
+        }
+
+        err = cudaFree(xWais);
+        if (err != 0) {
+            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for weights and indices", err);
+        }
+
+        err = cudaFree(coeffsTable);
+        if (err != 0) {
+            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for coefficients table", err);
+        }
+    }
+////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+    template <typename T>
+    int resizeBicubicFunctor_(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
+            bool preserveAspectRatio, bool antialias, NDArray* output) {
+        // stub: the antialiasing variant is not implemented yet
+        return Status::OK();
+    }
+
+    int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
+            bool preserveAspectRatio, bool antialias, NDArray* output) {
+        BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctor_, (context, image,
+                width, height, preserveAspectRatio, antialias, output), NUMERIC_TYPES);
+    }
+    BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctor_, (nd4j::LaunchContext * context, NDArray const* image, int width, int height,
+            bool preserveAspectRatio, bool antialias, NDArray* output), NUMERIC_TYPES);
+// ------------------------------------------------------------------------------------------------------------------ //
+// ------------------------------------------------------------------------------------------------------------------ //
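BUILD_SINGLE_SELECTOR above expands to a runtime switch over image->dataType() that forwards to the matching template instantiation, and BUILD_SINGLE_TEMPLATE forces those instantiations to be emitted. A standalone analogue of that dispatch with hypothetical names (not libnd4j code, and only two types for brevity):

    #include <stdexcept>

    enum class DataType { FLOAT32, DOUBLE }; // stand-in for nd4j's DataType

    template <typename T>
    int resizeBicubicImpl(const void* image, void* output) {
        // templated worker, one instantiation per numeric type
        return 0;
    }

    // what the selector macro boils down to: pick the instantiation at runtime
    int resizeBicubic(DataType dt, const void* image, void* output) {
        switch (dt) {
            case DataType::FLOAT32: return resizeBicubicImpl<float>(image, output);
            case DataType::DOUBLE:  return resizeBicubicImpl<double>(image, output);
        }
        throw std::runtime_error("resizeBicubic: unsupported data type");
    }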
+// simplified bicubic resize without antialiasing
+//
+    template <typename T>
+    int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
+            bool const alignCorners, bool const halfPixelCenters, NDArray* output) {
+
+        ImageResizerState st(alignCorners, halfPixelCenters); // align_corners, half_pixel_align
+        st.stream = context->getCudaStream();
+        NDArray::prepareSpecialUse({output}, {image});
+        int res = st.validateAndCreateOutput(image, width, height);
+        if (res == Status::OK())
+            bicubicInterpolateWithCaching<T>(image, st, halfPixelCenters, output);
+        NDArray::registerSpecialUse({output}, {image});
+        return res;
+    }
+
+    int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
+            bool const alignCorners, bool const halfPixelCenters, NDArray* output) {
+        BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context,
+                image, width, height, alignCorners, halfPixelCenters, output), NUMERIC_TYPES);
+    }
+    BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctorA_, (nd4j::LaunchContext * context,
+            NDArray const* image, int width, int height, bool const alignCorners, bool const halfPixelCenters, NDArray* output), NUMERIC_TYPES);
+
+////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+    int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
+            ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
+        switch (method) {
+            case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break;
+            case kResizeNearest:  return resizeNeighborFunctor(context, image, width, height, true, output); break;
+            case kResizeBicubic:  return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
+            case kResizeLanczos5:
+            case kResizeGaussian:
+            case kResizeArea:
+            case kResizeMitchelcubic:
+                throw std::runtime_error("helper::resizeFunctor: Not implemented yet.");
+        }
+        return ND4J_STATUS_OK;
+    }
+
+    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 // --------------------------------------------------------------------------------------------------------------- //
 // Crop and Resize helper implementation
- // --------------------------------------------------------------------------------------------------------------- //
- // cropAndResize kernel
+ // -------------------------------------------------------------------------------------------------------------- //
+ // cropAndResize kernel; the type of the input (images) and the output should be the same
 //
     template <typename T, typename Z, typename I>
     static __global__ void cropAndResizeKernel(T const *images, Nd4jLong* imagesShape, Z const* boxes, Nd4jLong* boxesShape,
             I const* indices, Nd4jLong* indexShape, I const* cropSize, Nd4jLong* cropShape, int method,
-            double extrapolationVal, Z* output, Nd4jLong* outputShape, int numBoxes, int cropHeight, int cropWidth,
+            double extrapolationVal, T* output, Nd4jLong* outputShape, int numBoxes, int cropHeight, int cropWidth,
             int batchSize, int imageHeight, int imageWidth, int depth) {
 
         for (int b = blockIdx.x; b < numBoxes; b += gridDim.x)
@@ -444,7 +1067,7 @@ namespace helpers {
         Z const* boxesBuf = reinterpret_cast<Z const*>(boxes->getSpecialBuffer());
         I const* indexBuf = reinterpret_cast<I const*>(indices->getSpecialBuffer());
         I const* cropSizes = reinterpret_cast<I const*>(cropSize->getSpecialBuffer());
-        Z* outBuf = reinterpret_cast<Z*>(crops->specialBuffer());
+        T* outBuf = reinterpret_cast<T*>(crops->specialBuffer());
 
         NDArray::prepareSpecialUse({crops}, {images, boxes, indices, cropSize});
         cropAndResizeKernel<<>>(imagesBuf, images->getSpecialShapeInfo(), boxesBuf, boxes->getSpecialShapeInfo(),
                 indexBuf, indices->getSpecialShapeInfo(),
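The sg_cb.cu hunks that follow all add the same guard: a cudaStreamSynchronize between an asynchronous kernel launch and the cudaFree of its temporary buffers, so that a launch or execution failure is reported at a known point instead of surfacing later. A minimal sketch of the pattern (hypothetical kernel, not the patch's code):

    #include <cuda_runtime.h>
    #include <stdexcept>

    __global__ void addVectorsKernel(float* dst, const float* src, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) dst[i] += src[i];
    }

    void addAndRelease(float* dst, float* src, int n, cudaStream_t stream) {
        addVectorsKernel<<<(n + 255) / 256, 256, 0, stream>>>(dst, src, n);
        // surface asynchronous kernel errors here, before the buffer is released
        auto err = cudaStreamSynchronize(stream);
        if (err != cudaSuccess) throw std::runtime_error(cudaGetErrorString(err));
        err = cudaFree(src);
        if (err != cudaSuccess) throw std::runtime_error(cudaGetErrorString(err));
    }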
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu b/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu
index 3b854415b..24116ec63 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu
@@ -215,6 +215,10 @@ namespace nd4j {
             } else {
                 addInfVectorKernel<<<128, 256, 256, *stream>>>(infVector, neu1e, vectorLength);
             }
+            err = cudaStreamSynchronize(*stream);
+            if (0 != err) {
+                throw cuda_exception::build("helpers::skipgram_: Cannot synchronize stream after addInfVectorKernel", err);
+            }
 
             err = cudaFree(neu1e);
             if (0 != err) {
@@ -317,10 +321,15 @@ namespace nd4j {
                 }
             }
             addInfVectorKernel<<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength);
+            err = cudaStreamSynchronize(*stream);
+            if (0 != err) {
+                throw cuda_exception::build("helpers::skipgramBatchExec_: Cannot synchronize stream after addInfVectorKernel", err);
+            }
 
             // optionally release temp arrays
             err = cudaFree(neu1e);
             if (err != 0) {
+                throw cuda_exception::build("helpers::skipgramBatchExec_: Cannot deallocate memory at this stage", err);
                 break;
             }
 //            if (vectorLength > 600)
@@ -485,9 +494,22 @@ namespace nd4j {
                         infVector[i] += neu1e[i];
                     }
                 }
-
+            err = cudaStreamSynchronize(*stream);
+            if (0 != err) {
+                throw cuda_exception::build("helpers::cbow_: Cannot synchronize stream after kernel execution", err);
+            }
             err = cudaFree(neu1);
+            if (0 != err) {
+                throw cuda_exception::build("helpers::cbow_: Cannot deallocate memory for synonyms table", err);
+            }
+            err = cudaFree(neu1e);
+            if (0 != err) {
+                throw cuda_exception::build("helpers::cbow_: Cannot deallocate memory for antonyms table", err);
+            }
         }
         BUILD_SINGLE_TEMPLATE(template void cbow_, (LaunchContext* lc, void *syn0, void *syn1, void *syn1Neg, void *expTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords), FLOAT_TYPES);
@@ -574,13 +596,13 @@ namespace nd4j {
             const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1);
             const auto numTargets = context.sizeAt(0);
             const int contextWidth = context.sizeAt(1);
-            const auto bContext = reinterpret_cast<int*>(context.buffer()); //bufferAsT<int>();
-            const auto dContext = reinterpret_cast<int*>(context.specialBuffer()); //bufferAsT<int>();
-            const auto bLocker = reinterpret_cast<int*>(lockedWords.buffer()); //lockedWords.bufferAsT<int>();
-            const auto dLocker = reinterpret_cast<int*>(lockedWords.specialBuffer()); //lockedWords.bufferAsT<int>();
-            const auto bIndices = reinterpret_cast<int*>(indices.buffer()); //AsT<int>();
-            const auto bCodes = reinterpret_cast<int8_t*>(codes.buffer()); //bufferAsT<int8_t>();
-            const auto bStarters = reinterpret_cast<int*>(negStarters.buffer()); //AsT<int>();
+            //const auto bContext = reinterpret_cast<int*>(context.buffer()); //bufferAsT<int>();
+            const auto dContext = context.dataBuffer()->specialAsT<int>(); //bufferAsT<int>();
+//            const auto bLocker = reinterpret_cast<int*>(lockedWords.buffer()); //lockedWords.bufferAsT<int>();
+            const auto dLocker = lockedWords.dataBuffer()->specialAsT<int>(); //.specialBuffer()); //lockedWords.bufferAsT<int>();
+            const auto bIndices = indices.dataBuffer()->primaryAsT<int>(); //buffer()); //AsT<int>();
+            const auto bCodes = codes.dataBuffer()->primaryAsT<int8_t>(); //reinterpret_cast<int8_t*>(codes.buffer()); //bufferAsT<int8_t>();
+            const auto bStarters = negStarters.dataBuffer()->primaryAsT<int>(); //reinterpret_cast<int*>(negStarters.buffer()); //AsT<int>();
             const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1);
             lr.syncToHost();
             nLabels.syncToHost();
@@ -678,6 +700,11 @@ namespace nd4j {
 //                }
             }
 
+            cerr = cudaStreamSynchronize(*stream);
+            if (cerr) {
+                throw cuda_exception::build("Cannot synchronize stream before memory deallocation", cerr);
+            }
+
             cerr = cudaFree(neu1);
             if (cerr) {
                 throw cuda_exception::build("Cannot deallocate temp buffer1", cerr);
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
index 8a9986e23..1a5a255ee 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
diff --git a/libnd4j/include/ops/declarable/helpers/image_resize.h b/libnd4j/include/ops/declarable/helpers/image_resize.h
index aa1b75986..22c41833b 100644
--- a/libnd4j/include/ops/declarable/helpers/image_resize.h
+++ b/libnd4j/include/ops/declarable/helpers/image_resize.h
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -26,9 +27,29 @@ namespace nd4j { namespace ops { namespace helpers { - int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center, NDArray* output); - int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center, NDArray* output); - void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes, NDArray const* indices, NDArray const* cropSize, int method, double extrapolationVal, NDArray* crops); + enum ImageResizeMethods { + kResizeBilinear = 1, + kResizeBicubic, + kResizeNearest, + kResizeGaussian, + kResizeLanczos5, + kResizeMitchelcubic, + kResizeArea + }; + + int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center, + NDArray* output); + int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center, + NDArray* output); + int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + bool preserveAspectRatio, bool antialias, NDArray* output); + int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + bool const alignCorners, bool const halfPixelAlign, NDArray* output); + int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output); + + void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes, + NDArray const* indices, NDArray const* cropSize, int method, double extrapolationVal, NDArray* crops); } } } diff --git a/libnd4j/include/ops/declarable/helpers/s_t_b.h b/libnd4j/include/ops/declarable/helpers/s_t_b.h index e761905fd..35e57a1e8 100644 --- a/libnd4j/include/ops/declarable/helpers/s_t_b.h +++ b/libnd4j/include/ops/declarable/helpers/s_t_b.h @@ -50,7 +50,7 @@ namespace helpers { Nd4jLong space_shape[NUM_BLOCK_DIMS]; Nd4jLong batch_shape[NUM_BLOCK_DIMS]; - const int batch_size = batch->sizeAt(0); + const int batchSize = batch->sizeAt(0); const int space_size = space->sizeAt(0); #pragma unroll @@ -65,7 +65,7 @@ namespace helpers { auto batch_strides = batch->stridesOf(); // TODO: this loop should be moved to _execute phase - for (int batch_b = 0; batch_b < batch_size; ++batch_b) { + for (int batch_b = 0; batch_b < batchSize; ++batch_b) { const Nd4jLong space_b = batch_b % space_size; Nd4jLong block_index = batch_b / space_size; Nd4jLong block_offsets[NUM_BLOCK_DIMS]; diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 5ee19b007..3aef09bcd 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -25,6 +25,8 @@ #include #include #include +#include +#include namespace nd4j { namespace ops { @@ -227,6 +229,15 @@ namespace nd4j { nd4j_printf("Expected vs provided shapes mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), pair.second); throw std::runtime_error("Expected vs provided shapes mismatch"); } + + /* + * FIXME: we want to uncomment this eventually, and check data types equality + //checking out data type equality + if (ArrayOptions::dataType(out) != ArrayOptions::dataType(shape)) { + 
std::string msg = "Provided array [" + StringUtils::valueToString(pair.second) + "] has unexpected data type"; + throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), ArrayOptions::dataType(shape)); + } + */ } } else { auto fout = ctx.fastpath_out(); @@ -237,6 +248,7 @@ namespace nd4j { ctx.setOutputArray(idx, outArr, true); } else { auto array = fout[idx]; + // checking out shape equality if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) { auto eShape = ShapeUtils::shapeAsString(out); auto aShape = ShapeUtils::shapeAsString(array->shapeInfo()); @@ -247,6 +259,15 @@ namespace nd4j { nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx); throw std::runtime_error("Expected vs provided shape mismatch"); } + + /* + * FIXME: we want to uncomment this eventually, and check data types equality + //checking out data type equality + if (ArrayOptions::dataType(out) != array->dataType()) { + std::string msg = "Provided array [" + StringUtils::valueToString(idx) + "] has unexpected data type"; + throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), array->dataType()); + } + */ } } } @@ -912,6 +933,10 @@ namespace nd4j { return ND4J_STATUS_OK; } + samediff::EmptyHandling DeclarableOp::emptyHandling() { + return samediff::EmptyHandling::EMPTY_SKIP; + } + void DeclarableOp::registerTypes() { this->getOpDescriptor()->setSameMode(true); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp index 927c53e40..e70aff9d9 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp @@ -27,7 +27,7 @@ #include "mkldnnUtils.h" #include -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -82,11 +82,11 @@ namespace nd4j { auto poolingMode = PoolingType::AVG_POOL; - mkldnn_memory_desc_t empty; - mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty); - mkldnn::memory::desc user_src_md(empty), user_dst_md(empty); - mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - mkldnn::algorithm algorithm; + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true, bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, @@ -99,20 +99,20 @@ namespace nd4j { pool_strides, pool_kernel, pool_padding, pool_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); auto pool_src_memory = user_src_memory; - mkldnn::stream stream(engine); + dnnl::stream stream(engine); if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), 
engine); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); } auto pool_dst_memory = user_dst_memory; if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); + pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); } - pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, - {MKLDNN_ARG_DST, pool_dst_memory}}); + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}}); if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp index bf7b11b70..ceef28d33 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp @@ -27,7 +27,7 @@ #include "mkldnnUtils.h" #include -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -89,11 +89,11 @@ namespace nd4j { auto poolingMode = PoolingType::AVG_POOL; - mkldnn_memory_desc_t empty; - mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - mkldnn::algorithm algorithm; + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true, bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, @@ -109,20 +109,20 @@ namespace nd4j { auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer()); - auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer()); + auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); + auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); auto poolB_src_memory = userB_src_memory; - mkldnn::stream stream(engine); + dnnl::stream stream(engine); if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine); + poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); } auto poolB_dst_memory = userB_dst_memory; if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine); + poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); } - pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory}, - {MKLDNN_ARG_DIFF_SRC, poolB_src_memory}}); + 
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, + {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp index 74fe7c8de..e42cb6a8e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp @@ -27,7 +27,7 @@ #include "mkldnnUtils.h" #include -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -86,11 +86,11 @@ namespace nd4j { auto poolingMode = PoolingType::AVG_POOL; - mkldnn_memory_desc_t empty; - mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty); - mkldnn::memory::desc user_src_md(empty), user_dst_md(empty); - mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - mkldnn::algorithm algorithm; + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output, @@ -102,21 +102,21 @@ namespace nd4j { pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); auto pool_src_memory = user_src_memory; if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); } auto pool_dst_memory = user_dst_memory; if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); + pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); } - pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, - {MKLDNN_ARG_DST, pool_dst_memory}}); + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}}); if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp index cc57e671d..370f2b3fd 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp @@ -26,7 +26,7 @@ #include "mkldnnUtils.h" #include -using 
namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -92,11 +92,11 @@ namespace nd4j { auto poolingMode = PoolingType::AVG_POOL; - mkldnn_memory_desc_t empty; - mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - mkldnn::algorithm algorithm; + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO, @@ -111,24 +111,24 @@ namespace nd4j { auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer()); - auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer()); + auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); + auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); auto poolB_src_memory = userB_src_memory; if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine); + poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); } auto poolB_dst_memory = userB_dst_memory; if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine); + poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); } - pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory}, - {MKLDNN_ARG_DIFF_SRC, poolB_src_memory}}); + pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, + {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 27f836a0e..23957aeb7 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -39,7 +39,7 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float 
epsilon, NDArray* z) { - // unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any) + // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) // also it gives wrong results for formats nhwc and ndhwc // x -> 2D:nc, 4D:nchw, 5D:ncdhw @@ -53,35 +53,35 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // input type - mkldnn::memory::data_type type = mkldnn::memory::data_type::f32; + dnnl::memory::data_type type = dnnl::memory::data_type::f32; // indicate whether gamma or/and beta are given - auto flags = mkldnn::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch + auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch if (weights != nullptr) - flags |= mkldnn::normalization_flags::use_scale_shift; + flags |= dnnl::normalization_flags::use_scale_shift; - mkldnn::memory::dims dims; - mkldnn::memory::format_tag format; + dnnl::memory::dims dims; + dnnl::memory::format_tag format; if(xRank == 2) { dims = {x->sizeAt(0), x->sizeAt(1)}; - format = mkldnn::memory::format_tag::nc; + format = dnnl::memory::format_tag::nc; } else if(xRank == 4) { dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; - format = mkldnn::memory::format_tag::nchw; + format = dnnl::memory::format_tag::nchw; } else { // xRank = 5 dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; - format = mkldnn::memory::format_tag::ncdhw; + format = dnnl::memory::format_tag::ncdhw; } // memory descriptors for arrays // x - mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format); - mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format); - x_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); + x_user_md.data.format_kind = dnnl_blocked; // overrides format x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; if(xRank > 2) { @@ -92,9 +92,9 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; // z, output - mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(dims, type, format); - mkldnn::memory::desc z_user_md = mkldnn::memory::desc(dims, type, format); - z_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); + z_user_md.data.format_kind = dnnl_blocked; // overrides format z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0]; z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1]; if(xRank > 2) { @@ -106,53 +106,53 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray // batchnorm forward description - mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags); - mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); + 
dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); // arguments (memory buffers) necessary for calculations - std::unordered_map args; + std::unordered_map args; - mkldnn::stream stream(engine); + dnnl::stream stream(engine); // provide memory and check whether reorder is required // x - auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); + auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? mkldnn::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem; + auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem; if (xReorder) - mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[MKLDNN_ARG_SRC] = x_mkl_mem; + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; // z - auto z_user_mem = mkldnn::memory(z_user_md, engine, z->getBuffer()); + auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? mkldnn::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; + auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; if (zReorder) - mkldnn::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem); - args[MKLDNN_ARG_DST] = z_mkl_mem; + dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem); + args[DNNL_ARG_DST] = z_mkl_mem; // mean - auto mean_mkl_mem = mkldnn::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer()); - args[MKLDNN_ARG_MEAN] = mean_mkl_mem; + auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer()); + args[DNNL_ARG_MEAN] = mean_mkl_mem; // variance - auto var_mkl_mem = mkldnn::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer()); - args[MKLDNN_ARG_VARIANCE] = var_mkl_mem; + auto var_mkl_mem = dnnl::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer()); + args[DNNL_ARG_VARIANCE] = var_mkl_mem; // gamma and beta (and their gradients) if they are present if(weights != nullptr) { - auto w_mkl_mem = mkldnn::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer()); - args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + auto w_mkl_mem = dnnl::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer()); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; } // run calculations - mkldnn::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); + dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); // reorder outputs if necessary if (zReorder) - mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); stream.wait(); @@ -164,7 +164,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights, const float epsilon, NDArray* dLdI, NDArray* dLdW) { - // unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any) + // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) // also it gives wrong results for formats nhwc and ndhwc // x -> 2D:nc, 4D:nchw, 5D:ncdhw @@ -180,35 +180,35 @@ static void 
batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // input type - mkldnn::memory::data_type type = mkldnn::memory::data_type::f32; + dnnl::memory::data_type type = dnnl::memory::data_type::f32; // indicate whether gamma or/and beta are given - auto flags = mkldnn::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch + auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch if (weights != nullptr) - flags |= mkldnn::normalization_flags::use_scale_shift; + flags |= dnnl::normalization_flags::use_scale_shift; - mkldnn::memory::dims dims; - mkldnn::memory::format_tag format; + dnnl::memory::dims dims; + dnnl::memory::format_tag format; if(xRank == 2) { dims = {x->sizeAt(0), x->sizeAt(1)}; - format = mkldnn::memory::format_tag::nc; + format = dnnl::memory::format_tag::nc; } else if(xRank == 4) { dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; - format = mkldnn::memory::format_tag::nchw; + format = dnnl::memory::format_tag::nchw; } else { // xRank = 5 dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; - format = mkldnn::memory::format_tag::ncdhw; + format = dnnl::memory::format_tag::ncdhw; } // memory descriptors for arrays // x - mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format); - mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format); - x_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); + x_user_md.data.format_kind = dnnl_blocked; // overrides format x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; if(xRank > 2) { @@ -219,9 +219,9 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; // dLdO - mkldnn::memory::desc dLdO_mkl_md = mkldnn::memory::desc(dims, type, format); - mkldnn::memory::desc dLdO_user_md = mkldnn::memory::desc(dims, type, format); - dLdO_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); + dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0]; dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1]; if(xRank > 2) { @@ -232,9 +232,9 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4]; // dLdI - mkldnn::memory::desc dLdI_mkl_md = mkldnn::memory::desc(dims, type, format); - mkldnn::memory::desc dLdI_user_md = mkldnn::memory::desc(dims, type, format); - dLdI_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); + dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0]; dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1]; if(xRank 
> 2) { @@ -245,66 +245,66 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4]; // batchnorm forward description - mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags); - mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); + dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); // batchnorm backprop description - mkldnn::batch_normalization_backward::desc op_bp_desc(mkldnn::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags); - mkldnn::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); + dnnl::batch_normalization_backward::desc op_bp_desc(dnnl::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags); + dnnl::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); // arguments (memory buffers) necessary for calculations - std::unordered_map args; + std::unordered_map args; - mkldnn::stream stream(engine); + dnnl::stream stream(engine); // provide memory and check whether reorder is required // x - auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); + auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? mkldnn::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem; + auto x_mkl_mem = xReorder ? dnnl::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem; if (xReorder) - mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[MKLDNN_ARG_SRC] = x_mkl_mem; + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; // dLdO - auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer()); + auto dLdO_user_mem = dnnl::memory(dLdO_user_md, engine, dLdO->getBuffer()); const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc(); - auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem; + auto dLdO_mkl_mem = dLdOReorder ? 
dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem; if (dLdOReorder) - mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem); - args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem; + dnnl::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem); + args[DNNL_ARG_DIFF_DST] = dLdO_mkl_mem; // mean - auto mean_mkl_mem = mkldnn::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); - args[MKLDNN_ARG_MEAN] = mean_mkl_mem; + auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); + args[DNNL_ARG_MEAN] = mean_mkl_mem; // variance - auto var_mkl_mem = mkldnn::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer()); - args[MKLDNN_ARG_VARIANCE] = var_mkl_mem; + auto var_mkl_mem = dnnl::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer()); + args[DNNL_ARG_VARIANCE] = var_mkl_mem; // dLdI - auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer()); + auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->getBuffer()); const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc(); - auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem; - args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem; + auto dLdI_mkl_mem = dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem; + args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem; // gamma and beta (and their gradients) if they are present if(weights != nullptr) { - auto w_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer()); - args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + auto w_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer()); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; - auto dLdW_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer()); - args[MKLDNN_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem; + auto dLdW_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer()); + args[DNNL_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem; } // run calculations - mkldnn::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); + dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); // reorder outputs if necessary if (dLdIReorder) - mkldnn::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem); + dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem); stream.wait(); @@ -532,37 +532,37 @@ PLATFORM_CHECK(batchnorm) { // weights({1, 2, 0, 0}).assign(0.0f); // mkldnn_memory_desc_t empty; -// mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty); +// dnnl::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty); -// auto flag = mkldnn::normalization_flags::use_global_stats; +// auto flag = dnnl::normalization_flags::use_global_stats; // if (applyScale || applyOffset) -// flag |= mkldnn::normalization_flags::use_scale_shift; +// flag |= dnnl::normalization_flags::use_scale_shift; // mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output, // &batchnorm_src_md, nullptr, &batchnorm_dst_md, // &user_src_md, nullptr, &user_dst_md, axes[0]); -// auto batchnorm_desc = mkldnn::batch_normalization_forward::desc(mkldnn::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag); +// auto batchnorm_desc = 
dnnl::batch_normalization_forward::desc(dnnl::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag); // auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); -// mkldnn::stream stream(engine); -// auto batchnorm_prim_desc = mkldnn::batch_normalization_forward::primitive_desc(batchnorm_desc, engine); -// auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); -// auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); -// auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine, +// dnnl::stream stream(engine); +// auto batchnorm_prim_desc = dnnl::batch_normalization_forward::primitive_desc(batchnorm_desc, engine); +// auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); +// auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); +// auto batchnorm_mean_memory = dnnl::memory(batchnorm_prim_desc.mean_desc(), engine, // mean->buffer()); -// auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine, +// auto batchnorm_variance_memory = dnnl::memory(batchnorm_prim_desc.variance_desc(), engine, // variance->buffer()); // auto batchnorm_src_memory = user_src_memory; -// mkldnn::memory m(batchnorm_src_md, engine); +// dnnl::memory m(batchnorm_src_md, engine); // if (m.get_desc() != user_src_memory.get_desc()) { -// batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine); -// mkldnn::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory, +// batchnorm_src_memory = dnnl::memory(batchnorm_src_md, engine); +// dnnl::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory, // batchnorm_src_memory); // } // auto batchnorm_dst_memory = user_dst_memory; // if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { -// batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine); +// batchnorm_dst_memory = dnnl::memory(batchnorm_prim_desc.dst_desc(), engine); // } // if (applyScale || applyOffset) { // if (gamma != nullptr) { @@ -572,22 +572,22 @@ PLATFORM_CHECK(batchnorm) { // weights({1, 2, 0, 0}).assign(beta); // } -// auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer()); -// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream, +// auto batchnorm_weights_memory = dnnl::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer()); +// dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream, // {{MKLDNN_ARG_SRC, batchnorm_src_memory}, // {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, // {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, // {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory}, // {MKLDNN_ARG_DST, batchnorm_dst_memory}}); // } else { -// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream, +// dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream, // {{MKLDNN_ARG_SRC, batchnorm_src_memory}, // {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, // {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, // {MKLDNN_ARG_DST, batchnorm_dst_memory}}); // } // if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { -// mkldnn::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory, +// dnnl::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory, // user_dst_memory); // } // stream.wait(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp 
b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index b6e7de7da..027c62484 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -27,7 +27,7 @@ #include "mkldnnUtils.h" #include -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -47,12 +47,12 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con if(isSameMode) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - mkldnn_memory_desc_t empty; - mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( + dnnl_memory_desc_t empty; + dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( empty); - mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( + dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( empty); - mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, bias, output, @@ -74,38 +74,38 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); - auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast(input)->buffer()); - auto user_weights_memory = mkldnn::memory(user_weights_md, engine, + auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); + auto user_weights_memory = dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); - auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); auto conv_src_memory = user_src_memory; if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { - conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine); + conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine); reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory); } auto conv_weights_memory = user_weights_memory; if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) { - conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine); + conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine); reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, conv_weights_memory); } auto conv_dst_memory = user_dst_memory; if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine); + conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine); } if (bias != nullptr) { - auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine, + auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, const_cast(bias)->buffer()); - convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory}, - {MKLDNN_ARG_WEIGHTS, 
conv_weights_memory}, - {MKLDNN_ARG_BIAS, conv_bias_memory}, - {MKLDNN_ARG_DST, conv_dst_memory}}); + convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, + {DNNL_ARG_WEIGHTS, conv_weights_memory}, + {DNNL_ARG_BIAS, conv_bias_memory}, + {DNNL_ARG_DST, conv_dst_memory}}); } else { - convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory}, - {MKLDNN_ARG_WEIGHTS, conv_weights_memory}, - {MKLDNN_ARG_DST, conv_dst_memory}}); + convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, + {DNNL_ARG_WEIGHTS, conv_weights_memory}, + {DNNL_ARG_DST, conv_dst_memory}}); } if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory); @@ -198,12 +198,12 @@ PLATFORM_IMPL(conv2d_bp) { if (isSameMode) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - mkldnn_memory_desc_t empty; - mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), + dnnl_memory_desc_t empty; + dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); - mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, gradB, gradO, @@ -235,47 +235,47 @@ PLATFORM_IMPL(conv2d_bp) { conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc); - auto userW_src_memory = mkldnn::memory(user_src_md, engine, + auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); - auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer()); - auto userW_dst_memory = mkldnn::memory(user_dst_md, engine, + auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); + auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convW_src_memory = userW_src_memory; if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { - convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine); + convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory); } auto convW_weights_memory = userW_weights_memory; if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine); + convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); } auto convW_dst_memory = userW_dst_memory; if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { - convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine); 
+ convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory); } if (gradB != nullptr) { - auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine, + auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); convolution_backward_weights(convW_prim_desc).execute(stream, - {{MKLDNN_ARG_SRC, convW_src_memory}, - {MKLDNN_ARG_DIFF_DST, convW_dst_memory}, - {MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}, - {MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}}); + {{DNNL_ARG_SRC, convW_src_memory}, + {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, + {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); } else { convolution_backward_weights(convW_prim_desc).execute(stream, - {{MKLDNN_ARG_SRC, convW_src_memory}, - {MKLDNN_ARG_DIFF_DST, convW_dst_memory}, - {MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}}); + {{DNNL_ARG_SRC, convW_src_memory}, + {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); } if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { @@ -293,38 +293,38 @@ PLATFORM_IMPL(conv2d_bp) { conv_padding, conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); - auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer()); - auto userI_weights_memory = mkldnn::memory(user_weights_md, engine, + auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); + auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); - auto userI_dst_memory = mkldnn::memory(user_dst_md, engine, + auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convI_src_memory = userI_src_memory; if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine); + convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); } auto convI_weights_memory = userI_weights_memory; if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { - convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine); + convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory); } auto convI_dst_memory = userI_dst_memory; if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { - convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine); + convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory); } convolution_backward_data(convI_prim_desc).execute(stream, - {{MKLDNN_ARG_DIFF_DST, convI_dst_memory}, - {MKLDNN_ARG_WEIGHTS, convI_weights_memory}, - {MKLDNN_ARG_DIFF_SRC, convI_src_memory}}); + {{DNNL_ARG_DIFF_DST, convI_dst_memory}, + {DNNL_ARG_WEIGHTS, convI_weights_memory}, + {DNNL_ARG_DIFF_SRC, convI_src_memory}}); if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, diff --git 
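The conv2d hunks above repeat one idiom for every tensor: wrap the user's buffer as a dnnl::memory, compare its descriptor with the layout the primitive prefers, and insert a reorder only when the two differ. A hedged sketch of that pattern (function and parameter names are assumptions, not code from this diff):

    #include <dnnl.hpp>

    // Returns memory in the primitive's preferred layout, reordering from the
    // user's buffer only when the two descriptors differ.
    static dnnl::memory toPrimitiveLayout(const dnnl::memory::desc &primMd,
                                          dnnl::memory &userMem,
                                          dnnl::engine &engine,
                                          dnnl::stream &stream) {
        if (primMd == userMem.get_desc())
            return userMem;                        // layouts already match
        dnnl::memory primMem(primMd, engine);      // scratch in preferred layout
        dnnl::reorder(userMem, primMem).execute(stream, userMem, primMem);
        return primMem;
    }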
a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 3d9a79535..a5871fada 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -27,7 +27,7 @@ #include "mkldnnUtils.h" #include -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -83,12 +83,12 @@ PLATFORM_IMPL(conv3dnew) { ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - mkldnn_memory_desc_t empty; - mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( + dnnl_memory_desc_t empty; + dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( empty); - mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( + dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( empty); - mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNCDHW, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights, @@ -110,37 +110,37 @@ PLATFORM_IMPL(conv3dnew) { conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); - auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast(input)->buffer()); - auto user_weights_memory = mkldnn::memory(user_weights_md, engine, + auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); + auto user_weights_memory = dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); - auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); auto conv_src_memory = user_src_memory; if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { - conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine); + conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine); reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory); } auto conv_weights_memory = user_weights_memory; if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) { - conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine); + conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine); reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, conv_weights_memory); } auto conv_dst_memory = user_dst_memory; if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine); + conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine); } if (bias != nullptr) { - auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine, bias->buffer()); - convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory}, - {MKLDNN_ARG_WEIGHTS, conv_weights_memory}, - {MKLDNN_ARG_BIAS, conv_bias_memory}, - {MKLDNN_ARG_DST, conv_dst_memory}}); + auto conv_bias_memory = 
dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer()); + convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, + {DNNL_ARG_WEIGHTS, conv_weights_memory}, + {DNNL_ARG_BIAS, conv_bias_memory}, + {DNNL_ARG_DST, conv_dst_memory}}); } else { - convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory}, - {MKLDNN_ARG_WEIGHTS, conv_weights_memory}, - {MKLDNN_ARG_DST, conv_dst_memory}}); + convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, + {DNNL_ARG_WEIGHTS, conv_weights_memory}, + {DNNL_ARG_DST, conv_dst_memory}}); } if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory); @@ -235,12 +235,12 @@ PLATFORM_IMPL(conv3dnew_bp) { oC, bias->rankOf(), bias->lengthOf()); - mkldnn_memory_desc_t empty; - mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), + dnnl_memory_desc_t empty; + dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); - mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, isNDHWC, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights, @@ -273,47 +273,47 @@ PLATFORM_IMPL(conv3dnew_bp) { conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc); - auto userW_src_memory = mkldnn::memory(user_src_md, engine, + auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); - auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer()); - auto userW_dst_memory = mkldnn::memory(user_dst_md, engine, + auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); + auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convW_src_memory = userW_src_memory; if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { - convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine); + convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory); } auto convW_weights_memory = userW_weights_memory; if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine); + convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); } auto convW_dst_memory = userW_dst_memory; if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { - convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine); + convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); 
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory); } if (gradB != nullptr) { - auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine, + auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); convolution_backward_weights(convW_prim_desc).execute(stream, - {{MKLDNN_ARG_SRC, convW_src_memory}, - {MKLDNN_ARG_DIFF_DST, convW_dst_memory}, - {MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}, - {MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}}); + {{DNNL_ARG_SRC, convW_src_memory}, + {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, + {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); } else { convolution_backward_weights(convW_prim_desc).execute(stream, - {{MKLDNN_ARG_SRC, convW_src_memory}, - {MKLDNN_ARG_DIFF_DST, convW_dst_memory}, - {MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}}); + {{DNNL_ARG_SRC, convW_src_memory}, + {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); } if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { @@ -330,38 +330,38 @@ PLATFORM_IMPL(conv3dnew_bp) { conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); - auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer()); - auto userI_weights_memory = mkldnn::memory(user_weights_md, engine, + auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); + auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); - auto userI_dst_memory = mkldnn::memory(user_dst_md, engine, + auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convI_src_memory = userI_src_memory; if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine); + convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); } auto convI_weights_memory = userI_weights_memory; if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { - convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine); + convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory); } auto convI_dst_memory = userI_dst_memory; if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { - convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine); + convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory); } convolution_backward_data(convI_prim_desc).execute(stream, - {{MKLDNN_ARG_DIFF_DST, convI_dst_memory}, - {MKLDNN_ARG_WEIGHTS, convI_weights_memory}, - {MKLDNN_ARG_DIFF_SRC, convI_src_memory}}); + {{DNNL_ARG_DIFF_DST, convI_dst_memory}, + {DNNL_ARG_WEIGHTS, convI_weights_memory}, + {DNNL_ARG_DIFF_SRC, convI_src_memory}}); if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp 
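Where the conv2d/conv3d hunks branch into two near-identical execute calls depending on whether a bias is present, the deconv hunks that follow build a single argument map and add the bias entry only when it exists. A compact sketch of that map style (all names illustrative):

    #include <dnnl.hpp>
    #include <unordered_map>

    static void runConvForward(dnnl::convolution_forward::primitive_desc &pd,
                               dnnl::stream &stream,
                               dnnl::memory &src, dnnl::memory &weights,
                               dnnl::memory &dst, dnnl::memory *bias) {
        std::unordered_map<int, dnnl::memory> args;
        args[DNNL_ARG_SRC]     = src;
        args[DNNL_ARG_WEIGHTS] = weights;
        if (bias != nullptr)                   // bias is optional
            args[DNNL_ARG_BIAS] = *bias;
        args[DNNL_ARG_DST]     = dst;
        dnnl::convolution_forward(pd).execute(stream, args);
        stream.wait();                         // block until results are ready
    }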
b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index 239e243ca..047549e40 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -46,80 +46,77 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW); - ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl); - - mkldnn::memory::dims strides = { sH, sW }; - mkldnn::memory::dims padding = { pH, pW }; - mkldnn::memory::dims padding_r = { pHmkl, pWmkl }; - mkldnn::memory::dims dilation = { dHmkl, dWmkl }; + dnnl::memory::dims strides = { sH, sW }; + dnnl::memory::dims padding = { pH, pW }; + dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; + dnnl::memory::dims dilation = { dH-1, dW-1 }; // input type - mkldnn::memory::data_type xType; + dnnl::memory::data_type xType; if(input->dataType() == DataType::FLOAT32) - xType = mkldnn::memory::data_type::f32; + xType = dnnl::memory::data_type::f32; else if(input->dataType() == DataType::HALF) - xType = mkldnn::memory::data_type::f16; + xType = dnnl::memory::data_type::f16; else if(input->dataType() == DataType::UINT8) - xType = mkldnn::memory::data_type::u8; + xType = dnnl::memory::data_type::u8; else - xType = mkldnn::memory::data_type::s8; + xType = dnnl::memory::data_type::s8; // weights type - mkldnn::memory::data_type wType = xType; - if(xType == mkldnn::memory::data_type::u8) - wType = mkldnn::memory::data_type::s8; + dnnl::memory::data_type wType = xType; + if(xType == dnnl::memory::data_type::u8) + wType = dnnl::memory::data_type::s8; // output and bias type (have the same types) - mkldnn::memory::data_type zType; + dnnl::memory::data_type zType; if(output->dataType() == DataType::FLOAT32) - zType = mkldnn::memory::data_type::f32; + zType = dnnl::memory::data_type::f32; else if(output->dataType() == DataType::HALF) - zType = mkldnn::memory::data_type::f16; + zType = dnnl::memory::data_type::f16; else if(output->dataType() == DataType::UINT8) - zType = mkldnn::memory::data_type::u8; + zType = dnnl::memory::data_type::u8; else if(output->dataType() == DataType::INT8) - zType = mkldnn::memory::data_type::s8; + zType = dnnl::memory::data_type::s8; else - zType = mkldnn::memory::data_type::s32; + zType = dnnl::memory::data_type::s32; - mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; - mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw; + dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? 
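The deconv2d rewrite drops calcPaddingAndDilationForConv2DMKL and computes the right padding directly as (iH - 1) * sH - oH + kH - pH, with dilation stored in DNNL's off-by-one convention (dH - 1). For unit dilation, DNNL's documented deconvolution shape rule is oH = sH * (iH - 1) + kH - padL - padR, and substituting the new padding_r recovers oH exactly. A tiny check with made-up sizes:

    #include <cassert>

    int main() {
        const int iH = 14, kH = 3, sH = 2, pH = 1;      // hypothetical sizes
        const int oH = sH * (iH - 1) + kH - 2 * pH;     // symmetric padding: 27
        const int padR = (iH - 1) * sH - oH + kH - pH;  // formula from this diff
        // DNNL deconv, unit dilation: dst = (src - 1) * stride + kernel - padL - padR
        assert(sH * (iH - 1) + kH - pH - padR == oH);
        return 0;
    }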
dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; - mkldnn::memory::dims xDims = {bS, iC, iH, iW}; - mkldnn::memory::dims wDims = {oC, iC, kH, kW}; - mkldnn::memory::dims zDims = {bS, oC, oH, oW}; + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; // memory descriptors for arrays // input - mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat); - x_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); + x_user_md.data.format_kind = dnnl_blocked; // overrides format x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3]; // weights - mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat); - w_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; // bias - mkldnn::memory::desc b_mkl_md; + dnnl::memory::desc b_mkl_md; if(bias != nullptr) - b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x); + b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); // output - mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat); - z_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); + z_user_md.data.format_kind = dnnl_blocked; // overrides format z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0]; z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1]; z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2]; @@ -128,51 +125,51 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // operation primitive description - mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, + dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - 
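Every user-side descriptor in these files is built the same way: construct it with a plain format tag, then flip format_kind to dnnl_blocked and copy the NDArray's actual strides in, so arbitrarily-strided nd4j buffers bind to DNNL memory without an upfront copy. A sketch of that idiom for a 4-d tensor (the helper name and the fixed nchw tag are assumptions):

    #include <dnnl.hpp>

    static dnnl::memory::desc userDesc4d(const dnnl::memory::dims &dims,
                                         dnnl::memory::data_type dtype,
                                         const long long *strides /* stridesOf() */) {
        dnnl::memory::desc md(dims, dtype, dnnl::memory::format_tag::nchw);
        md.data.format_kind = dnnl_blocked;           // override the format tag
        for (int i = 0; i < 4; ++i)                   // element strides per axis
            md.data.format_desc.blocking.strides[i] = strides[i];
        return md;
    }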
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); + dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); // arguments (memory buffers) necessary for calculations - std::unordered_map args; + std::unordered_map args; - mkldnn::stream stream(engine); + dnnl::stream stream(engine); // provide memory buffers and check whether reorder is required // input - auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer()); + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem; + auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; if (xReorder) - mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[MKLDNN_ARG_SRC] = x_mkl_mem; + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; // weights - auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer()); + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; + auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; if (wReorder) - mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; // bias if(bias != nullptr) { - auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer()); - args[MKLDNN_ARG_BIAS] = b_mkl_mem; + auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer()); + args[DNNL_ARG_BIAS] = b_mkl_mem; } // output - auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer()); + auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[MKLDNN_ARG_DST] = z_mkl_mem; + auto z_mkl_mem = zReorder ? 
dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; // run calculations - mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args); + dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary if (zReorder) - mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); stream.wait(); @@ -193,160 +190,157 @@ static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW); - ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl); - - mkldnn::memory::dims strides = { sH, sW }; - mkldnn::memory::dims padding = { pH, pW }; - mkldnn::memory::dims padding_r = { pHmkl, pWmkl }; - mkldnn::memory::dims dilation = { dHmkl, dWmkl }; + dnnl::memory::dims strides = { sH, sW }; + dnnl::memory::dims padding = { pH, pW }; + dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; + dnnl::memory::dims dilation = { dH-1, dW-1 }; // input type - mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // weights type - mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // gradO type - mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // gradI type - mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // gradW type - mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // gradB type - mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32; + dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? 
mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; - mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw; + dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; - mkldnn::memory::dims xDims = {bS, iC, iH, iW}; - mkldnn::memory::dims wDims = {oC, iC, kH, kW}; - mkldnn::memory::dims zDims = {bS, oC, oH, oW}; + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; // memory descriptors for arrays // input - mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat); - x_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); + x_user_md.data.format_kind = dnnl_blocked; // overrides format x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3]; // weights - mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat); - w_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; // gradO - mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat); - gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0]; gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1]; gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2]; gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3]; // gradI - mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat); - gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, 
gradIType, xFormat); + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0]; gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1]; gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2]; gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3]; // gradW - mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat); - gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0]; gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1]; gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2]; gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3]; // gradB - mkldnn::memory::desc gradB_mkl_md; + dnnl::memory::desc gradB_mkl_md; if(gradB != nullptr) - gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x); + gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // forward primitive description - mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); // backward data primitive description - mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); + dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); // backward weights primitive description - mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); + dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); // arguments (memory 
buffers) necessary for calculations - std::unordered_map args; + std::unordered_map args; - mkldnn::stream stream(engine); + dnnl::stream stream(engine); // provide memory buffers and check whether reorder is required // input - auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer()); + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; + auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; if (xReorder) - mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[MKLDNN_ARG_SRC] = x_mkl_mem; + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; // weights - auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer()); + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; + auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; if (wReorder) - mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; // gradO - auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer()); + auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; if (gradOReorder) - mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); - args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem; + dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem; // gradI - auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer()); + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem; + auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; // gradW - auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer()); + auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer()); const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + auto gradW_mkl_mem = gradWReorder ? 
dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; + args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; // gradB if(gradB != nullptr) { - auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer()); - args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem; + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; } // run backward data calculations - mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); // run backward weights calculations - mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); + dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); // reorder gradI if necessary if (gradIReorder) - mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); + dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); if (gradWReorder) - mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); + dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); stream.wait(); @@ -423,12 +417,17 @@ PLATFORM_CHECK(deconv2d) { auto output = INPUT_VARIABLE(0); + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + const DataType xType = input->dataType(); const DataType wType = weights->dataType(); const DataType zType = output->dataType(); const DataType bType = bias != nullptr ? bias->dataType() : zType; - return block.isUseMKLDNN() && ( + return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) && + ( (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) ); @@ -521,6 +520,9 @@ PLATFORM_CHECK(deconv2d_bp) { auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME const DataType xType = input->dataType(); const DataType wType = weights->dataType(); @@ -530,7 +532,7 @@ PLATFORM_CHECK(deconv2d_bp) { const DataType gradWType = gradW->dataType(); const DataType gradBType = gradB != nullptr ? 
gradB->dataType() : DataType::FLOAT32; - return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); + return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp index 5a1ed7d72..b0e27240f 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp @@ -39,52 +39,52 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC] // gradO [bS, oH, oW, oC] - mkldnn::memory::dims strides = { sH, sW }; - mkldnn::memory::dims dilation = { dH - 1, dW - 1 }; - mkldnn::memory::dims padding = { pH, pW }; - mkldnn::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; + dnnl::memory::dims strides = { sH, sW }; + dnnl::memory::dims dilation = { dH - 1, dW - 1 }; + dnnl::memory::dims padding = { pH, pW }; + dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; // weights type - mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // gradO type - mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // gradI type - mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; - mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw; + dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? 
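Both PLATFORM_CHECK(deconv2d) and PLATFORM_CHECK(deconv2d_bp) now read the dilation and padding-mode integer args (INT_ARG(6), INT_ARG(7), INT_ARG(8)) and refuse the MKL-DNN path for dilated or SAME-padded deconvolutions, which the simplified padding arithmetic above does not attempt to cover; such inputs fall through to nd4j's default implementation. The added condition, restated as a sketch (the helper name is hypothetical):

    // Mirrors the guard added to both checks: only unit dilation and explicit
    // (VALID) padding are sent down the DNNL deconvolution path.
    static bool deconv2dSupportedByMkldnn(int dH, int dW, int isSameMode) {
        return dH <= 1 && dW <= 1 && !isSameMode;
    }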
dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; - mkldnn::memory::dims xDims = {bS, iC, iH, iW}; - mkldnn::memory::dims wDims = {oC, iC, kH, kW}; - mkldnn::memory::dims zDims = {bS, oC, oH, oW}; + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; // memory descriptors for arrays // input - mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, gradOType, mkldnn::memory::format_tag::any); + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, gradOType, dnnl::memory::format_tag::any); // weights - mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat); - w_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; // gradO - mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat); - gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0]; gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1]; gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2]; gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3]; // gradI - mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat); - gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0]; gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1]; gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2]; @@ -94,48 +94,48 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // forward primitive description - mkldnn::convolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - mkldnn::convolution_forward::primitive_desc 
op_ff_prim_desc(op_ff_desc, engine); + dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); // backward data primitive description - mkldnn::convolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - mkldnn::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); + dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); // arguments (memory buffers) necessary for calculations - std::unordered_map args; + std::unordered_map args; - mkldnn::stream stream(engine); + dnnl::stream stream(engine); // provide memory buffers and check whether reorder is required // weights - auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer()); + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; + auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; if (wReorder) - mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; // gradO - auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer()); + auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; if (gradOReorder) - mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); - args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem; + dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem; // gradI - auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer()); + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem; + auto gradI_mkl_mem = gradIReorder ? 
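Outputs mirror the input handling in reverse: gradI is bound in the primitive's preferred diff_src layout without a pre-reorder (its contents will be overwritten anyway), and only after execution is it reordered back into the user's buffer. A condensed sketch following the structure of the deconv2d_tf hunk (names assumed):

    #include <dnnl.hpp>
    #include <unordered_map>

    static void backwardData(dnnl::convolution_backward_data::primitive_desc &pd,
                             dnnl::stream &stream,
                             std::unordered_map<int, dnnl::memory> &args,
                             dnnl::memory &gradIUser) {
        const bool reorderBack = pd.diff_src_desc() != gradIUser.get_desc();
        auto gradIMkl = reorderBack
                ? dnnl::memory(pd.diff_src_desc(), gradIUser.get_engine())
                : gradIUser;
        args[DNNL_ARG_DIFF_SRC] = gradIMkl;
        dnnl::convolution_backward_data(pd).execute(stream, args);
        if (reorderBack)   // copy the result back into the caller's layout
            dnnl::reorder(gradIMkl, gradIUser).execute(stream, gradIMkl, gradIUser);
        stream.wait();
    }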
dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; // run backward data calculations - mkldnn::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); // reorder gradI if necessary if (gradIReorder) - mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); + dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); stream.wait(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index d1d7ca87f..51e15349b 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -47,57 +47,54 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW); - ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl); - - mkldnn::memory::dims strides = { sD, sH, sW }; - mkldnn::memory::dims padding = { pD, pH, pW }; - mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl }; - mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl }; + dnnl::memory::dims strides = { sD, sH, sW }; + dnnl::memory::dims padding = { pD, pH, pW }; + dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; + dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; // input type - mkldnn::memory::data_type xType; + dnnl::memory::data_type xType; if(input->dataType() == DataType::FLOAT32) - xType = mkldnn::memory::data_type::f32; + xType = dnnl::memory::data_type::f32; else if(input->dataType() == DataType::HALF) - xType = mkldnn::memory::data_type::f16; + xType = dnnl::memory::data_type::f16; else if(input->dataType() == DataType::UINT8) - xType = mkldnn::memory::data_type::u8; + xType = dnnl::memory::data_type::u8; else - xType = mkldnn::memory::data_type::s8; + xType = dnnl::memory::data_type::s8; // weights type - mkldnn::memory::data_type wType = xType; - if(xType == mkldnn::memory::data_type::u8) - wType = mkldnn::memory::data_type::s8; + dnnl::memory::data_type wType = xType; + if(xType == dnnl::memory::data_type::u8) + wType = dnnl::memory::data_type::s8; // output and bias type (have the same types) - mkldnn::memory::data_type zType; + dnnl::memory::data_type zType; if(output->dataType() == DataType::FLOAT32) - zType = mkldnn::memory::data_type::f32; + zType = dnnl::memory::data_type::f32; else if(output->dataType() == DataType::HALF) - zType = mkldnn::memory::data_type::f16; + zType = dnnl::memory::data_type::f16; else if(output->dataType() == DataType::UINT8) - zType = mkldnn::memory::data_type::u8; + zType = dnnl::memory::data_type::u8; else if(output->dataType() == DataType::INT8) - zType = mkldnn::memory::data_type::s8; + zType = dnnl::memory::data_type::s8; else - zType = mkldnn::memory::data_type::s32; + zType = dnnl::memory::data_type::s32; - mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw; - mkldnn::memory::format_tag wFormat = 
mkldnn::memory::format_tag::oidhw; + dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; - mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW}; - mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW}; - mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW}; + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; // memory descriptors for arrays // input - mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat); - x_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); + x_user_md.data.format_kind = dnnl_blocked; // overrides format x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; @@ -105,9 +102,9 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4]; // weights - mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat); - w_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; @@ -115,14 +112,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4]; // bias - mkldnn::memory::desc b_mkl_md; + dnnl::memory::desc b_mkl_md; if(bias != nullptr) - b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x); + b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); // output - mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat); - z_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); + z_user_md.data.format_kind = dnnl_blocked; // overrides format z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0]; z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1]; z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2]; @@ -132,51 +129,51 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // operation primitive description - mkldnn::deconvolution_forward::desc 
op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, + dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); + dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); // arguments (memory buffers) necessary for calculations - std::unordered_map args; + std::unordered_map args; - mkldnn::stream stream(engine); + dnnl::stream stream(engine); // provide memory buffers and check whether reorder is required // input - auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer()); + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem; + auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; if (xReorder) - mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[MKLDNN_ARG_SRC] = x_mkl_mem; + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; // weights - auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer()); + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; + auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; if (wReorder) - mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; // bias if(bias != nullptr) { - auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer()); - args[MKLDNN_ARG_BIAS] = b_mkl_mem; + auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer()); + args[DNNL_ARG_BIAS] = b_mkl_mem; } // output - auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer()); + auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[MKLDNN_ARG_DST] = z_mkl_mem; + auto z_mkl_mem = zReorder ? 
dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; // run calculations - mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args); + dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary if (zReorder) - mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); stream.wait(); @@ -197,40 +194,37 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW); - ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl); - - mkldnn::memory::dims strides = { sD, sH, sW }; - mkldnn::memory::dims padding = { pD, pH, pW }; - mkldnn::memory::dims padding_r = { pDmkl, pHmkl, pWmkl }; - mkldnn::memory::dims dilation = { dDmkl, dHmkl, dWmkl }; + dnnl::memory::dims strides = { sD, sH, sW }; + dnnl::memory::dims padding = { pD, pH, pW }; + dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; + dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; // input type - mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // weights type - mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // gradO type - mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // gradI type - mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // gradW type - mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16; + dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // gradB type - mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32; + dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw; // isNCDHW ? 
mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc; - mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw; + dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw; // isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; - mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW}; - mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW}; - mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW}; + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; // memory descriptors for arrays // input - mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat); - x_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); + x_user_md.data.format_kind = dnnl_blocked; // overrides format x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; @@ -238,9 +232,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4]; // weights - mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat); - w_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; @@ -248,9 +242,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4]; // gradO - mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any); - mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat); - gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0]; gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1]; gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2]; @@ -258,9 +252,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4]; // gradI - mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any); 
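
Aside on the padding_r / dilation hunks above: the inlined expression replaces ConvolutionUtils::calcPaddingAndDilationForConv3DMKL with a closed form. A minimal standalone sketch of the arithmetic follows (not part of the patch; the helper name rightPadDeconv and the sample sizes are illustrative only). oneDNN's deconvolution shape relation, with zero-based dilation, is o = (i - 1) * s + k - pL - pR, so the right padding that makes the declared dims consistent is pR = (i - 1) * s - o + k - pL, exactly the expression assigned to padding_r; likewise a conventional dilation d is passed to dnnl as d - 1, hence the {dD-1, dH-1, dW-1} dims.

#include <cassert>

// hypothetical helper name; the patch writes this expression out per dimension
inline int rightPadDeconv(int i, int o, int k, int s, int pL) {
    return (i - 1) * s - o + k - pL;
}

int main() {
    const int iD = 4, sD = 2, kD = 3, pD = 1;    // sample sizes, not from the patch
    const int oD = (iD - 1) * sD + kD - 2 * pD;  // symmetric padding -> oD = 7
    assert(rightPadDeconv(iD, oD, kD, sD, pD) == pD);
    return 0;
}

Together with the PLATFORM_CHECK guards added further down (dD/dH/dW <= 1 and no SAME mode), this closed form appears to cover every case that still reaches the MKL-DNN path, which is presumably why the separate helper could be dropped.
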
- mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat); - gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0]; gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1]; gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2]; @@ -268,9 +262,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4]; // gradW - mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, wFormat); - mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat); - gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format + dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat); + dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0]; gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1]; gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2]; @@ -278,85 +272,85 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4]; // gradB - mkldnn::memory::desc gradB_mkl_md; + dnnl::memory::desc gradB_mkl_md; if(gradB != nullptr) - gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x); + gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // forward primitive description - mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); // backward data primitive description - mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); + dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); // backward weights primitive description - mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, 
padding, padding_r); - mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); + dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); // arguments (memory buffers) necessary for calculations - std::unordered_map args; + std::unordered_map args; - mkldnn::stream stream(engine); + dnnl::stream stream(engine); // provide memory buffers and check whether reorder is required // input - auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer()); + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; + auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; if (xReorder) - mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[MKLDNN_ARG_SRC] = x_mkl_mem; + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; // weights - auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer()); + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); - auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; + auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; if (wReorder) - mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); - args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; // gradO - auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer()); + auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; if (gradOReorder) - mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); - args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem; + dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem; // gradI - auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer()); + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem; + auto gradI_mkl_mem = gradIReorder ? 
dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; // gradW - auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer()); + auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer()); const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; + args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; // gradB if(gradB != nullptr) { - auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer()); - args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem; + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; } // run backward data calculations - mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); // run backward weights calculations - mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); + dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); // reorder gradI if necessary if (gradIReorder) - mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); + dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); if (gradWReorder) - mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); + dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); stream.wait(); @@ -437,12 +431,18 @@ PLATFORM_CHECK(deconv3d) { auto output = INPUT_VARIABLE(0); + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + const DataType xType = input->dataType(); const DataType wType = weights->dataType(); const DataType zType = output->dataType(); const DataType bType = bias != nullptr ? bias->dataType() : zType; - return block.isUseMKLDNN() && ( + return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && + ( (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) ); @@ -538,6 +538,11 @@ PLATFORM_CHECK(deconv3d_bp) { auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + const DataType xType = input->dataType(); const DataType wType = weights->dataType(); const DataType gradOType = gradO->dataType(); @@ -546,7 +551,7 @@ PLATFORM_CHECK(deconv3d_bp) { const DataType gradWType = gradW->dataType(); const DataType gradBType = gradB != nullptr ? 
gradB->dataType() : DataType::FLOAT32; - return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); + return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); } } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp index aa4f9272a..41efe6524 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp @@ -27,7 +27,7 @@ #include "mkldnnUtils.h" #include -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -44,8 +44,8 @@ namespace nd4j { double bias = T_ARG(0); int depth = INT_ARG(0); - mkldnn_memory_desc_t empty; - mkldnn::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty); + dnnl_memory_desc_t empty; + dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty); mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md, &user_src_md, nullptr, &user_dst_md, input->rankOf() - 1); @@ -54,24 +54,24 @@ namespace nd4j { lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine); - auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); auto lrn_src_memory = user_src_memory; if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) { - lrn_src_memory = mkldnn::memory(lrn_prim_desc.src_desc(), engine); + lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine); reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory); } auto lrn_dst_memory = user_dst_memory; if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - lrn_dst_memory = mkldnn::memory(lrn_prim_desc.dst_desc(), engine); + lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine); } - lrn_forward(lrn_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, lrn_src_memory}, - {MKLDNN_ARG_DST, lrn_dst_memory}}); + lrn_forward(lrn_prim_desc).execute(stream, {{DNNL_ARG_SRC, lrn_src_memory}, + {DNNL_ARG_DST, lrn_dst_memory}}); if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp 
b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index a2667c9f2..50a349cc9 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -21,7 +21,7 @@ #include #include "mkldnnUtils.h" -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -132,52 +132,52 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md, + dnnl::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md, x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md; // input type - mkldnn::memory::data_type xType; + dnnl::memory::data_type xType; if(x->dataType() == DataType::FLOAT32) - xType = mkldnn::memory::data_type::f32; + xType = dnnl::memory::data_type::f32; else if(x->dataType() == DataType::HALF) - xType = mkldnn::memory::data_type::f16; + xType = dnnl::memory::data_type::f16; else - xType = mkldnn::memory::data_type::u8; + xType = dnnl::memory::data_type::u8; // weights type - mkldnn::memory::data_type wType = xType; - if(xType == mkldnn::memory::data_type::u8) - wType = mkldnn::memory::data_type::s8; + dnnl::memory::data_type wType = xType; + if(xType == dnnl::memory::data_type::u8) + wType = dnnl::memory::data_type::s8; // bias type - mkldnn::memory::data_type bType = xType; - if(xType == mkldnn::memory::data_type::u8) - bType = mkldnn::memory::data_type::f32; + dnnl::memory::data_type bType = xType; + if(xType == dnnl::memory::data_type::u8) + bType = dnnl::memory::data_type::f32; // output type - mkldnn::memory::data_type hType; + dnnl::memory::data_type hType; if(h->dataType() == DataType::FLOAT32) - hType = mkldnn::memory::data_type::f32; + hType = dnnl::memory::data_type::f32; else if(h->dataType() == DataType::HALF) - hType = mkldnn::memory::data_type::f16; + hType = dnnl::memory::data_type::f16; else - hType = mkldnn::memory::data_type::u8; + hType = dnnl::memory::data_type::u8; // memory descriptors for arrays // x - x_lstm_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::any); - // x_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, nIn}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, nIn}, type, mkldnn::memory::format_tag::ntc); - x_user_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::tnc); - x_user_md.data.format_kind = mkldnn_blocked; // overrides format + x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any); + // x_user_md = dataFormat == 0 ? 
dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc); + x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc); + x_user_md.data.format_kind = dnnl_blocked; // overrides format x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; // wx - wx_lstm_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::any); - wx_user_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::ldigo); - wx_user_md.data.format_kind = mkldnn_blocked; // overrides format + wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any); + wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo); + wx_user_md.data.format_kind = dnnl_blocked; // overrides format wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0]; wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1]; wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2]; @@ -185,9 +185,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4]; // wr - wr_lstm_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::any); - wr_user_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::ldigo); - wr_user_md.data.format_kind = mkldnn_blocked; // overrides format + wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any); + wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo); + wr_user_md.data.format_kind = dnnl_blocked; // overrides format wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0]; wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1]; wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2]; @@ -195,19 +195,19 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4]; // h - h_lstm_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::any); - // h_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, hDirDim*nOut}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, hDirDim*nOut}, type, mkldnn::memory::format_tag::ntc); - h_user_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::tnc); - h_user_md.data.format_kind = mkldnn_blocked; // overrides format + h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any); + // h_user_md = dataFormat == 0 ? 
dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc); + h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc); + h_user_md.data.format_kind = dnnl_blocked; // overrides format h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0]; h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1]; h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2]; // b if(b) { - b_lstm_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::any); - b_user_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::ldgo); - b_user_md.data.format_kind = mkldnn_blocked; // overrides format + b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any); + b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo); + b_user_md.data.format_kind = dnnl_blocked; // overrides format b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0]; b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1]; b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2]; @@ -216,9 +216,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* // hI if(hI) { - hI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any); - hI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc); - hI_user_md.data.format_kind = mkldnn_blocked; // overrides format + hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any); + hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc); + hI_user_md.data.format_kind = dnnl_blocked; // overrides format hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0]; hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1]; hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2]; @@ -227,9 +227,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* // cI if(cI) { - cI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any); - cI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc); - cI_user_md.data.format_kind = mkldnn_blocked; // overrides format + cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any); + cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc); + cI_user_md.data.format_kind = dnnl_blocked; // overrides format cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0]; cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1]; cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2]; @@ -238,9 +238,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* // hL if(hL) { - hL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::any); - hL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc); - hL_user_md.data.format_kind = mkldnn_blocked; // overrides format + hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any); + hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); + hL_user_md.data.format_kind = dnnl_blocked; // 
overrides format hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0]; hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1]; hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2]; @@ -248,9 +248,9 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* } if(cL) { - cL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc); - cL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc); - cL_user_md.data.format_kind = mkldnn_blocked; // overrides format + cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); + cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); + cL_user_md.data.format_kind = dnnl_blocked; // overrides format cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0]; cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1]; cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2]; @@ -262,92 +262,92 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); // lstm primitive description lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine); // arguments (memory buffers) necessary for calculations - std::unordered_map args; + std::unordered_map args; // provide memory and check whether reorder is required // x - auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); + auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc(); - auto x_lstm_mem = xReorder ? mkldnn::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem; + auto x_lstm_mem = xReorder ? dnnl::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem; if (xReorder) reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem); - args[MKLDNN_ARG_SRC_LAYER] = x_lstm_mem; + args[DNNL_ARG_SRC_LAYER] = x_lstm_mem; // wx - auto wx_user_mem = mkldnn::memory(wx_user_md, engine, Wx->getBuffer()); + auto wx_user_mem = dnnl::memory(wx_user_md, engine, Wx->getBuffer()); const bool wxReorder = lstm_prim_desc.weights_layer_desc()!= wx_user_mem.get_desc(); - auto wx_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem; + auto wx_lstm_mem = wxReorder ? dnnl::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem; if (wxReorder) reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem); - args[MKLDNN_ARG_WEIGHTS_LAYER] = wx_lstm_mem; + args[DNNL_ARG_WEIGHTS_LAYER] = wx_lstm_mem; // wr - auto wr_user_mem = mkldnn::memory(wr_user_md, engine, Wr->getBuffer()); + auto wr_user_mem = dnnl::memory(wr_user_md, engine, Wr->getBuffer()); const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc(); - auto wr_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem; + auto wr_lstm_mem = wrReorder ? 
dnnl::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem; if (wrReorder) reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem); - args[MKLDNN_ARG_WEIGHTS_ITER] = wr_lstm_mem; + args[DNNL_ARG_WEIGHTS_ITER] = wr_lstm_mem; // h - auto h_user_mem = mkldnn::memory(h_user_md, engine, h->getBuffer()); + auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer()); const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); - auto h_lstm_mem = hReorder ? mkldnn::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem; - args[MKLDNN_ARG_DST_LAYER] = h_lstm_mem; + auto h_lstm_mem = hReorder ? dnnl::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem; + args[DNNL_ARG_DST_LAYER] = h_lstm_mem; // b if(b) { - auto b_user_mem = mkldnn::memory(b_user_md, engine, b->getBuffer()); + auto b_user_mem = dnnl::memory(b_user_md, engine, b->getBuffer()); const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc(); - auto b_lstm_mem = bReorder ? mkldnn::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem; + auto b_lstm_mem = bReorder ? dnnl::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem; if (bReorder) reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem); - args[MKLDNN_ARG_BIAS] = b_lstm_mem; + args[DNNL_ARG_BIAS] = b_lstm_mem; } // hI if(hI) { - auto hI_user_mem = mkldnn::memory(hI_user_md, engine, hI->getBuffer()); + auto hI_user_mem = dnnl::memory(hI_user_md, engine, hI->getBuffer()); const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc(); - auto hI_lstm_mem = hIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem; + auto hI_lstm_mem = hIReorder ? dnnl::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem; if (hIReorder) reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem); - args[MKLDNN_ARG_SRC_ITER] = hI_lstm_mem; + args[DNNL_ARG_SRC_ITER] = hI_lstm_mem; } // cI if(cI) { - auto cI_user_mem = mkldnn::memory(cI_user_md, engine, cI->getBuffer()); + auto cI_user_mem = dnnl::memory(cI_user_md, engine, cI->getBuffer()); const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc(); - auto cI_lstm_mem = cIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem; + auto cI_lstm_mem = cIReorder ? dnnl::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem; if (cIReorder) reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem); - args[MKLDNN_ARG_SRC_ITER_C] = cI_lstm_mem; + args[DNNL_ARG_SRC_ITER_C] = cI_lstm_mem; } bool hLReorder(false), cLReorder(false); - mkldnn::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem; + dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem; // hL if(hL) { - hL_user_mem = mkldnn::memory(hL_user_md, engine, hL->getBuffer()); + hL_user_mem = dnnl::memory(hL_user_md, engine, hL->getBuffer()); hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc(); - hL_lstm_mem = hLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem; - args[MKLDNN_ARG_DST_ITER] = hL_lstm_mem; + hL_lstm_mem = hLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem; + args[DNNL_ARG_DST_ITER] = hL_lstm_mem; } // cL if(cL) { - cL_user_mem = mkldnn::memory(cL_user_md, engine, cL->getBuffer()); + cL_user_mem = dnnl::memory(cL_user_md, engine, cL->getBuffer()); cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc(); - cL_lstm_mem = cLReorder ? 
mkldnn::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem; - args[MKLDNN_ARG_DST_ITER_C] = cL_lstm_mem; + cL_lstm_mem = cLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem; + args[DNNL_ARG_DST_ITER_C] = cL_lstm_mem; } // run calculations diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp index 86115d723..4204b93d0 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp @@ -27,7 +27,7 @@ #include "mkldnnUtils.h" #include -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -82,11 +82,11 @@ namespace nd4j { auto poolingMode = PoolingType::MAX_POOL; int extraParam0 = 1; - mkldnn_memory_desc_t empty; - mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty); - mkldnn::memory::desc user_src_md(empty), user_dst_md(empty); - mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - mkldnn::algorithm algorithm; + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true, @@ -102,23 +102,23 @@ namespace nd4j { auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); auto pool_src_memory = user_src_memory; - mkldnn::stream stream(engine); + dnnl::stream stream(engine); if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); } auto pool_dst_memory = user_dst_memory; if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); + pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); } - pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, - {MKLDNN_ARG_DST, pool_dst_memory}}); + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}}); if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp index aaead1f26..0c663a59c 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp @@ -27,7 +27,7 @@ #include "mkldnnUtils.h" #include -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -89,11 +89,11 @@ namespace nd4j { auto poolingMode = PoolingType::MAX_POOL; 
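
Aside on the pooling hunks above: every platform op in this patch repeats the same user-buffer/primitive-buffer handshake that the maxpooling2d code shows: wrap the NDArray buffer in a dnnl::memory carrying its actual strides, compare that descriptor against the layout the primitive chose, and reorder only on mismatch. A condensed sketch of the pattern under the oneDNN v1.x C++ API follows; prepareInput is a hypothetical helper name, the patch itself spells this logic out inline per tensor.

#include <dnnl.hpp>

// Wrap the framework buffer with its real strides, then reorder into the
// primitive-chosen layout only when the two descriptors actually differ.
static dnnl::memory prepareInput(const dnnl::engine& engine, dnnl::stream& stream,
                                 const dnnl::memory::desc& user_md,
                                 const dnnl::memory::desc& prim_md, void* buffer) {
    auto user_mem = dnnl::memory(user_md, engine, buffer);
    if (prim_md == user_mem.get_desc())
        return user_mem;                                // layouts match: zero copy
    auto prim_mem = dnnl::memory(prim_md, engine);      // scratch in primitive layout
    dnnl::reorder(user_mem, prim_mem).execute(stream, user_mem, prim_mem);
    return prim_mem;
}

Outputs run the same check in reverse: the primitive executes into the primitive-layout scratch, and a trailing reorder copies back into the user buffer, as the post-execute reorder calls in these hunks do; the backward ops differ only in comparing the diff_src/diff_dst descriptors and in threading the forward primitive's workspace memory into the backward execute call.
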
- mkldnn_memory_desc_t empty; - mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - mkldnn::algorithm algorithm; + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true, @@ -109,44 +109,44 @@ namespace nd4j { pool_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = mkldnn::memory(user_src_md, engine, gradI->buffer()); - auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer()); + auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); + auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); auto poolB_src_memory = userB_src_memory; if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine); + poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); } auto poolB_dst_memory = userB_dst_memory; if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine); + poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); } - auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); auto pool_src_memory = user_src_memory; if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); } - auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); - auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine); + auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); - pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, - {MKLDNN_ARG_DST, pool_dst_memory}, - {MKLDNN_ARG_WORKSPACE, pool_workspace_memory}}); + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); // probably wrong, fix that - pooling_backward(poolB_prim_desc).execute(stream, {{MKLDNN_ARG_DIFF_DST, poolB_dst_memory}, - {MKLDNN_ARG_WORKSPACE, pool_workspace_memory}, - {MKLDNN_ARG_DIFF_SRC, 
poolB_src_memory}}); + pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}, + {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp index b77059f8f..72fb79709 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp @@ -26,7 +26,7 @@ #include "mkldnnUtils.h" #include -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace ops { @@ -87,11 +87,11 @@ namespace nd4j { auto poolingMode = PoolingType::MAX_POOL; auto extraParam0 = 1; - mkldnn_memory_desc_t empty; - mkldnn::memory::desc pool_src_md(empty), pool_dst_md(empty); - mkldnn::memory::desc user_src_md(empty), user_dst_md(empty); - mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - mkldnn::algorithm algorithm; + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true, @@ -106,24 +106,24 @@ namespace nd4j { pool_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); auto pool_src_memory = user_src_memory; if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); } auto pool_dst_memory = user_dst_memory; if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); + pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); } - pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, - {MKLDNN_ARG_DST, pool_dst_memory}}); + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}}); if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp index af0be5897..b4c9f1ad5 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp @@ -26,7 +26,7 @@ #include "mkldnnUtils.h" #include -using namespace mkldnn; +using 
namespace dnnl; namespace nd4j { namespace ops { @@ -93,11 +93,11 @@ namespace nd4j { auto poolingMode = PoolingType::MAX_POOL; auto extraParam0 = 1; - mkldnn_memory_desc_t empty; - mkldnn::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - mkldnn::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - mkldnn::algorithm algorithm; + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true, @@ -115,44 +115,44 @@ namespace nd4j { auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); + dnnl::stream stream(engine); auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer()); - auto userB_dst_memory = mkldnn::memory(user_dst_md, engine, gradO->buffer()); + auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); + auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); auto poolB_src_memory = userB_src_memory; if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = mkldnn::memory(poolB_prim_desc.diff_src_desc(), engine); + poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); } auto poolB_dst_memory = userB_dst_memory; if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = mkldnn::memory(poolB_prim_desc.diff_dst_desc(), engine); + poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); } - auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); auto pool_src_memory = user_src_memory; if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = mkldnn::memory(pool_prim_desc.src_desc(), engine); + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); } - auto pool_dst_memory = mkldnn::memory(pool_prim_desc.dst_desc(), engine); - auto pool_workspace_memory = mkldnn::memory(pool_prim_desc.workspace_desc(), engine); + auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); - pooling_forward(pool_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, pool_src_memory}, - {MKLDNN_ARG_DST, pool_dst_memory}, - {MKLDNN_ARG_WORKSPACE, pool_workspace_memory}}); - pooling_backward(poolB_prim_desc).execute(stream, 
{{MKLDNN_ARG_DIFF_DST, poolB_dst_memory}, - {MKLDNN_ARG_WORKSPACE, pool_workspace_memory}, - {MKLDNN_ARG_DIFF_SRC, poolB_src_memory}}); + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); + pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}, + {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 084fb760b..07cf60ef9 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -18,23 +18,22 @@ // @author saudet // -#include +#include #include "mkldnnUtils.h" -#include -using namespace mkldnn; +using namespace dnnl; namespace nd4j { namespace mkldnnUtils { void getMKLDNNMemoryDescPool2d( int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, - mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, - mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) { - mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW }; - mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW }; + const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, + dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { + dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW }; + dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW }; pool_strides = { sH, sW }; pool_kernel = { kH, kW }; @@ -45,14 +44,14 @@ namespace nd4j { algorithm = poolingMode == 0 ? algorithm::pooling_max : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - auto type = mkldnn::memory::data_type::f32; - auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; - auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any" + auto type = dnnl::memory::data_type::f32; + auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? 
nchw : nhwc" + *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; @@ -60,9 +59,9 @@ namespace nd4j { } if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" + *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; @@ -70,9 +69,9 @@ namespace nd4j { } if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" + *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 
2 : 1]; @@ -84,12 +83,12 @@ namespace nd4j { void getMKLDNNMemoryDescPool3d( int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, - mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, - mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) { - mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; - mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; + const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, + dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { + dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; + dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; pool_strides = { sD, sH, sW }; pool_kernel = { kD, kH, kW }; @@ -101,14 +100,14 @@ namespace nd4j { algorithm = poolingMode == 0 ? algorithm::pooling_max : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - auto type = mkldnn::memory::data_type::f32; - auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc; - auto supposed_to_be_any_format = mkldnn::memory::format_tag::nCdhw8c; // doesn't work with "any" + auto type = dnnl::memory::data_type::f32; + auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any" if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; @@ -117,9 +116,9 @@ namespace nd4j { } if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? 
ncdhw : ndhwc" + *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; @@ -128,9 +127,9 @@ namespace nd4j { } if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; @@ -145,38 +144,29 @@ namespace nd4j { int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, - mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md, - mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md, - mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) { - mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW }; - mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW }; - mkldnn::memory::dims conv_bias_tz = { oC }; - mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW }; - - int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW); - nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(iH, iW, oH, oW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl); - - conv_strides = { sH, sW }; - conv_padding = { pH, pW }; - conv_padding_r = { pHmkl, pWmkl }; - conv_dilation = { dHmkl, dWmkl }; + dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, + dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, + dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { + dnnl::memory::dims 
conv_src_tz = { bS, iC, iH, iW }; + dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW }; + dnnl::memory::dims conv_bias_tz = { oC }; + dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW }; conv_strides = { sH, sW }; conv_padding = { pH, pW }; + conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; conv_dilation = { dH-1, dW-1}; - conv_padding_r = { (oH - 1) * sH - iH + kH - pH, - (oW - 1) * sW - iW + kW - pW }; - auto type = mkldnn::memory::data_type::f32; - auto format = isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; - auto formatw = mkldnn::memory::format_tag::hwio; + auto type = dnnl::memory::data_type::f32; + auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + auto formatw = dnnl::memory::format_tag::hwio; if (src != nullptr && conv_src_md != nullptr) { - *conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any); - *user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format); - user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" + *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; @@ -184,9 +174,9 @@ namespace nd4j { } if (diff_src != nullptr && conv_diff_src_md != nullptr) { - *conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any); - *user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format); - user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" + *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 
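// ---------------------------------------------------------------------------
// NOTE (illustrative arithmetic, not part of the original patch): the removed
// lines asked a helper (calcPaddingAndDilationForConv2DMKL) for the padding;
// the new code derives the right/bottom padding directly from the output
// size:
//
//     padding_r = (o - 1) * s - i + k - p_left
//
// Worked example with iH = 6, kH = 2, sH = 1, pH = 0 and SAME output oH = 6
// (the shapes of the conv2d_8 test added further below):
//
//     (6 - 1) * 1 - 6 + 2 - 0 = 1   // one extra row of padding at the bottom
//
// DNNL also encodes dilation as the extra gap between kernel taps, so a dense
// kernel has dilation 0, hence conv_dilation = { dH - 1, dW - 1 }.
// ---------------------------------------------------------------------------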
2 : 1]; @@ -194,9 +184,9 @@ namespace nd4j { } if (weights != nullptr && conv_weights_md != nullptr) { - *conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any); - *user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw); - user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio" + *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3]; user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2]; user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; @@ -204,9 +194,9 @@ namespace nd4j { } if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { - *conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any); - *user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw); - user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = hwio" + *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3]; user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2]; user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; @@ -214,14 +204,14 @@ namespace nd4j { } if (bias != nullptr && conv_bias_md != nullptr) { - *conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any); - *user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x); + *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); + *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); } if (dst != nullptr && conv_dst_md != nullptr) { - *conv_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any); - *user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format); - user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" + *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); + *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 
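// ---------------------------------------------------------------------------
// NOTE (illustrative sketch, not part of the original patch): the weights
// descriptor in the hunk above is built with the logical OIHW shape
// {oC, iC, kH, kW}, while libnd4j stores conv weights as [kH, kW, iC, oC]
// (hwio). The stride assignments therefore permute the NDArray strides into
// logical order:
//
//     blocking.strides[0] = weights->stridesOf()[3];   // O <- axis 3 (oC)
//     blocking.strides[1] = weights->stridesOf()[2];   // I <- axis 2 (iC)
//     blocking.strides[2] = weights->stridesOf()[0];   // H <- axis 0 (kH)
//     blocking.strides[3] = weights->stridesOf()[1];   // W <- axis 1 (kW)
//
// The 3d variant below applies the same idea to dhwio weights with logical
// shape {oC, iC, kD, kH, kW}.
// ---------------------------------------------------------------------------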
2 : 1]; @@ -233,32 +223,29 @@ namespace nd4j { int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, - mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md, - mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md, - mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) { - mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW }; - mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; - mkldnn::memory::dims conv_bias_tz = { oC }; - mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW }; - - int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW); - nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl); + dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, + dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, + dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { + dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW }; + dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; + dnnl::memory::dims conv_bias_tz = { oC }; + dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW }; conv_strides = { sD, sH, sW }; conv_padding = { pD, pH, pW }; - conv_padding_r = { pDmkl, pHmkl, pWmkl }; - conv_dilation = { dDmkl, dHmkl, dWmkl }; + conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; + conv_dilation = { dD-1, dH-1, dW-1}; - auto type = mkldnn::memory::data_type::f32; - auto format = isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc; - auto formatw = mkldnn::memory::format_tag::dhwio; + auto type = dnnl::memory::data_type::f32; + auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + auto formatw = dnnl::memory::format_tag::dhwio; if (src != nullptr && conv_src_md != nullptr) { - *conv_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any); - *user_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format); - user_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? 
ncdhw : ndhwc" user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; @@ -267,9 +254,9 @@ namespace nd4j { } if (diff_src != nullptr && conv_diff_src_md != nullptr) { - *conv_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, mkldnn::memory::format_tag::any); - *user_diff_src_md = mkldnn::memory::desc({ conv_src_tz }, type, format); - user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; @@ -278,9 +265,9 @@ namespace nd4j { } if (weights != nullptr && conv_weights_md != nullptr) { - *conv_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any); - *user_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw); - user_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio" + *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4]; user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3]; user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; @@ -289,9 +276,9 @@ namespace nd4j { } if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { - *conv_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, mkldnn::memory::format_tag::any); - *user_diff_weights_md = mkldnn::memory::desc({ conv_weights_tz }, type, formatw); - user_diff_weights_md->data.format_kind = mkldnn_blocked; // overrides "formatw = dhwio" + *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4]; user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3]; user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; @@ -300,14 +287,14 @@ namespace nd4j { } if (bias != nullptr && conv_bias_md != nullptr) { - *conv_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::any); - *user_bias_md = mkldnn::memory::desc({ conv_bias_tz }, type, mkldnn::memory::format_tag::x); + *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); + *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); } if (dst != nullptr && conv_dst_md != nullptr) { - *conv_dst_md = 
mkldnn::memory::desc({ conv_dst_tz }, type, mkldnn::memory::format_tag::any); - *user_dst_md = mkldnn::memory::desc({ conv_dst_tz }, type, format); - user_dst_md->data.format_kind = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); + *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; @@ -318,23 +305,23 @@ namespace nd4j { // void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - // mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md, - // mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) { + // dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, + // dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { // const Nd4jLong* shape = src->getShapeInfo(); // Nd4jLong rank = shape[0]; // Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one // Nd4jLong dim2 = axis >= 2 ? 1 : 2; // Nd4jLong dim3 = axis >= 3 ? 2 : 3; - // mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; + // dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; - // auto type = mkldnn::memory::data_type::f32; - // auto format = mkldnn::memory::format_tag::nchw; - // auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any" + // auto type = dnnl::memory::data_type::f32; + // auto format = dnnl::memory::format_tag::nchw; + // auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" // if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { - // *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - // *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); - // user_src_md->data.format_kind = mkldnn_blocked; // overrides format + // *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + // *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); + // user_src_md->data.format_kind = dnnl_blocked; // overrides format // user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; // user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; // user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? 
src->stridesOf()[dim2] : 1; @@ -342,9 +329,9 @@ namespace nd4j { // } // if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { - // *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - // *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); - // user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format + // *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + // *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); + // user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format // user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; // user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; // user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; @@ -352,9 +339,9 @@ namespace nd4j { // } // if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { - // *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - // *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); - // user_dst_md->data.format_kind = mkldnn_blocked; // overrides format + // *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + // *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); + // user_dst_md->data.format_kind = dnnl_blocked; // overrides format // user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; // user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; // user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; @@ -364,23 +351,23 @@ namespace nd4j { void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) { + dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { const Nd4jLong* shape = src->getShapeInfo(); long rank = shape[0]; long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one long dim2 = axis >= 2 ? 1 : 2; long dim3 = axis >= 3 ? 2 : 3; - mkldnn::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; + dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; - auto type = mkldnn::memory::data_type::f32; - auto format = axis == 1 ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc; + auto type = dnnl::memory::data_type::f32; + auto format = axis == 1 ? 
dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; auto supposed_to_be_any_format = format; // doesn't work with "any" if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) { - *lrn_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format); - user_src_md->data.format_kind = mkldnn_blocked; + *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; @@ -388,9 +375,9 @@ namespace nd4j { } if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) { - *lrn_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = mkldnn::memory::desc({ lrn_src_tz }, type, format); - user_diff_src_md->data.format_kind = mkldnn_blocked; + *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; @@ -398,9 +385,9 @@ namespace nd4j { } if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) { - *lrn_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_dst_md = mkldnn::memory::desc({ lrn_src_tz }, type, format); - user_dst_md->data.format_kind = mkldnn_blocked; + *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? 
dst->stridesOf()[dim2] : 1; @@ -408,8 +395,8 @@ namespace nd4j { } } - mkldnn::engine& getEngine(void *ptr) { - auto eng = reinterpret_cast(ptr); + dnnl::engine& getEngine(void *ptr) { + auto eng = reinterpret_cast(ptr); return *eng; } } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index 9bc13427e..1f9b9e010 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -23,7 +23,7 @@ #include #include -#include +#include #include #include #include @@ -89,47 +89,47 @@ namespace nd4j{ int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, - mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md, - mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md, - mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation); + dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, + dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, + dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation); void getMKLDNNMemoryDescConv3d( int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, - mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md, - mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md, - mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation); + dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, + dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, + dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, 
dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation); void getMKLDNNMemoryDescPool2d( int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, - mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, - mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r); + const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, + dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r); void getMKLDNNMemoryDescPool3d( int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, - mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, - mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r); + const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, + dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r); void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis); + dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - mkldnn::memory::desc* lrn_src_md, mkldnn::memory::desc* lrn_diff_src_md, mkldnn::memory::desc* lrn_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis); + dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, 
dnnl::memory::desc* user_dst_md, int axis); - mkldnn::engine& getEngine(void *ptr); + dnnl::engine& getEngine(void *ptr); } } diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index ab4bfca90..f3f9f9699 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -1718,6 +1718,32 @@ namespace simdOps { } }; + template <typename X> + class Mish { + public: + no_op_exec_special_same + no_op_exec_special_same_cuda + + op_def static X op(X d1, X *params) { + return d1 * nd4j::math::nd4j_tanh(nd4j::math::nd4j_softplus(d1)); + } + }; + + template <typename X> + class MishDerivative { + public: + no_op_exec_special_same + no_op_exec_special_same_cuda + + op_def static X op(X d1, X *params) { + auto ex = nd4j::math::nd4j_exp(d1); + auto e2x = ex * ex; + auto e3x = ex * ex * ex; + + return (ex * (4 * (d1 + 1) + 4 * e2x + e3x + ex *(4 * d1 + 6))) / nd4j::math::nd4j_pow((2 * ex + e2x + 2), (X) 2.f); + } + }; + template <typename X> class GELU { public: @@ -1954,7 +1980,7 @@ no_op_exec_special_same_cuda op_def static X op(X d1, X *params) { - return nd4j::math::softplus(d1); + return nd4j::math::nd4j_softplus(d1); } }; @@ -2276,54 +2302,66 @@ return old + opOutput; } - // this op return 1.0 if condition met, 0.0 otherwise - op_def static Z op(X d1, X *extraParams) { - X compare = extraParams[0]; - X eps = extraParams[1]; - - auto mode = static_cast<int>(extraParams[2]); - //printf("value: %f; comp: %f; eps: %f; mode: %i;\n", (float) d1, (float) compare, (float) eps, mode); - - switch (mode) { - case 0: // equals - return nd4j::math::nd4j_abs(d1 - compare) <= eps ? 1 : 0; - case 1: // not equals - return nd4j::math::nd4j_abs(d1 - compare) > eps ? 1 : 0; - case 2: // less_than - return d1 < compare ? 1 : 0; - case 3: // greater_than - return d1 > compare ? 1 : 0; - case 4: // less_or_equals_than - return d1 <= compare ? 1 : 0; - case 5: // greater_or_equals_than - return d1 >= compare ? 1 : 0; - case 6: // abs_less_than - return nd4j::math::nd4j_abs(d1) < compare ? 1 : 0; - case 7: // abs_greater_than - return nd4j::math::nd4j_abs(d1) > compare ? 1 : 0; - case 8: // is inf - return nd4j::math::nd4j_isinf(d1) ? 1 : 0; - case 9: // is nan - return nd4j::math::nd4j_isnan(d1) ? 1 : 0; - case 10: - return (d1 == compare) ? 1 : 0; - case 11: - return (d1 != compare) ? 1 : 0; - case 12: // abs_greater_or_equals_than - return nd4j::math::nd4j_abs(d1) >= compare ? 1 : 0; - case 13: // abs_less_or_equals_than - return nd4j::math::nd4j_abs(d1) <= compare ? 1 : 0; + op_def static Z op(X d1, X compare, X eps, int mode) { + switch (mode) { + case 0: // equals + return nd4j::math::nd4j_abs(d1 - compare) <= eps ? 1 : 0; + case 1: // not equals + return nd4j::math::nd4j_abs(d1 - compare) > eps ? 1 : 0; + case 2: // less_than + return d1 < compare ? 1 : 0; + case 3: // greater_than + return d1 > compare ? 1 : 0; + case 4: // less_or_equals_than + return d1 <= compare ? 1 : 0; + case 5: // greater_or_equals_than + return d1 >= compare ? 1 : 0; + case 6: // abs_less_than + return nd4j::math::nd4j_abs(d1) < compare ? 1 : 0; + case 7: // abs_greater_than + return nd4j::math::nd4j_abs(d1) > compare ? 1 : 0; + case 8: // is inf + return nd4j::math::nd4j_isinf(d1) ? 1 : 0; + case 9: // is nan + return nd4j::math::nd4j_isnan(d1) ? 1 : 0; + case 10: + return (d1 == compare) ? 1 : 0; + case 11: + return (d1 != compare) ? 1 : 0; + case 12: // abs_greater_or_equals_than + return nd4j::math::nd4j_abs(d1) >= compare ? 1 : 0; + case 13: // abs_less_or_equals_than + return nd4j::math::nd4j_abs(d1) <= compare ? 
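// ---------------------------------------------------------------------------
// NOTE (illustrative math check, not part of the original patch): the Mish
// pair added above implements mish(x) = x * tanh(softplus(x)) with the
// closed-form derivative
//
//     mish'(x) = e^x * (4(x+1) + 4e^{2x} + e^{3x} + e^x(4x+6))
//                / (2e^x + e^{2x} + 2)^2
//
// Sanity check at x = 0: numerator e^0 * (4 + 4 + 1 + 6) = 15 and
// denominator (2 + 1 + 2)^2 = 25, so mish'(0) = 15/25 = 0.6, which matches
// tanh(softplus(0)) = tanh(ln 2) = 1.5 / 2.5 = 0.6.
//
// The MatchCondition change below is a pure refactor: the mode switch moves
// into a shared op(d1, compare, eps, mode) overload, and the two
// extraParams-based entry points simply unpack their parameters and delegate
// to it.
// ---------------------------------------------------------------------------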
1 : 0; case 14: // isFinite return !(nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1)) ? 1 : 0; case 15: // isInfinite return nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1) ? 1 : 0; - default: - printf("Undefined match condition: [%i]\n", mode); - } + default: + printf("Undefined match condition: [%i]\n", mode); + } return d1; + } + + // this op return 1.0 if condition met, 0.0 otherwise + op_def static Z op(X d1, X compare, X *extraParams) { + X eps = extraParams[1]; + + auto mode = static_cast(extraParams[0]); + + return op(d1, compare, eps, mode); + } + + // this op return 1.0 if condition met, 0.0 otherwise + op_def static Z op(X d1, X *extraParams) { + X compare = extraParams[0]; + X eps = extraParams[1]; + + auto mode = static_cast(extraParams[2]); + + return op(d1, compare, eps, mode); } op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h index 23f6b342d..55cd4033a 100644 --- a/libnd4j/include/templatemath.h +++ b/libnd4j/include/templatemath.h @@ -74,6 +74,9 @@ namespace nd4j { template math_def inline Z nd4j_copysign(T val1, T val2); + template + math_def inline Z nd4j_softplus(T val); + //#ifndef __CUDACC__ template math_def inline Z nd4j_dot(X *x, Y *y, int length); @@ -159,7 +162,7 @@ namespace nd4j { math_def inline Z nd4j_sinh(T val); template - math_def inline Z softplus(T val) { + math_def inline Z nd4j_softplus(T val) { return nd4j_log((Z) 1.0f + nd4j_exp(val)); } diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 23208ce1f..3b49c735d 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -255,25 +255,6 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_6) { delete result; } -TYPED_TEST(TypedConvolutionTests1, TestAvgFF_TF) { - auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, 
-1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, 
-0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 
0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, 
-2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 
1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); - auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); - - nd4j::ops::avgpool2d op; - auto result = op.execute({&input}, {}, {3,3, 3,3, 0,0, 1,1,1, 0,1}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - // z->printIndexedBuffer("z"); - // exp.printIndexedBuffer("e"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_7) { @@ -297,6 +278,75 @@ TYPED_TEST(TypedConvolutionTests1, 
conv2d_7) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_8) { + + int bS=1, iH=6,iW=8, iC=2,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=6,oW=8; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, + 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, + 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, + 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, + 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, 0.608414, 0.956500, 0.390098}); + + NDArray weights('c', {kH, kW, iC, oC}, {0.07581716775894165, 0.8706002235412598, 0.29345420002937317, 0.5281786322593689, 0.10540834069252014, 0.3663792014122009, 0.17209206521511078, 0.6257694959640503}); + NDArray bias('c', {1, oC}, {0.7414038777351379, 0.8980839848518372}); + + NDArray expOutput('c', {bS, oC, oH, oW}, {1.112878, 1.106691, 0.914598, 1.127438, 0.988108, 1.070572, 1.040759, 0.962728, 0.927537, 1.109045, 0.893301, 1.101278, 1.080314, + 1.112327, 1.030041, 0.955914, 0.779137, 1.110499, 0.944709, 1.195986, 0.997814, 1.083822, 1.090898, 0.889572, 0.964781, 1.071012, 1.111928, 1.291319, 1.085454, 0.977661, + 1.149068, 1.077099, 1.068283, 1.064290, 1.177125, 1.212480, 0.932593, 0.939493, 1.118576, 1.056927, 0.780314, 0.845707, 0.996308, 0.963152, 0.906792, 0.937590, 1.048791, + 0.860346, 2.264212, 2.071576, 1.916629, 2.030785, 2.169075, 2.039786, 1.935480, 2.177816, 1.524273, 1.933327, 1.630923, 2.406983, 1.770406, 2.413284, 1.790349, 1.476586, + 1.179925, 1.909109, 2.009143, 2.299778, 1.957207, 1.779718, 2.480604, 1.529086, 1.748063, 1.952856, 2.029487, 2.699131, 1.879842, 1.471205, 2.150177, 2.039078, 1.933456, + 1.764169, 2.584944, 2.521004, 1.744296, 1.707578, 2.237938, 2.325231, 0.984485, 1.766936, 1.590640, 1.347524, 1.404648, 1.422042, 1.709862, 1.155412}); + + nd4j::ops::conv2d op; + auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results->at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +TYPED_TEST(TypedConvolutionTests1, TestAvgFF_TF) { + + auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, + 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 
1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, + -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, + -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, + 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, + 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, + -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, + 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, + -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, + 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, + 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 
2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 
5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 
8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, 
-6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); + auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 
3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); + + nd4j::ops::avgpool2d op; + auto result = op.execute({&input}, {}, {3,3, 3,3, 0,0, 1,1,1, 0,1}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + // z->printIndexedBuffer("z"); + // exp.printIndexedBuffer("e"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, sconv2d_1) { float _expB[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f, 54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f, 70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f, 86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f, 101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f, 117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f, 67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f, 95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f, 124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f, 152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f, 180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f, 208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f, 295150.0f, 297350.0f, 299550.0f, 301750.0f, 303950.0f, 306150.0f, 317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f, 339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f, 361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f, 383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f, 426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f, 461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f, 495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f, 530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f, 564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f, 599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f, 75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f, 78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f, 81525.0f, 81850.0f, 82175.0f, 82500.0f, 
82825.0f, 83150.0f, 84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f, 88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f, 91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f, 660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f, 689025.0f, 691850.0f, 694675.0f, 697500.0f, 700325.0f, 703150.0f, 717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f, 745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f, 773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f, 309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f, 318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f, 328400.0f, 329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f, 337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f, 347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f, 356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f, 1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f, 1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f, 1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f, 1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f, 1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f,}; @@ -436,37 +486,39 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_3) { delete result; } -TYPED_TEST(TypedConvolutionTests1, deconv2D_FF_NoBias_1) { +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, sconv2d_4) { - int bS=2, iH=4,iW=4, iC=3,oC=3, kH=5,kW=5, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=8,oW=8; + int bS=1, iH=6,iW=6, iC=3,oC=2,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=6,oW=6; int paddingMode = 0; // 1-SAME, 0-VALID; int dataFormat = 0; // 1-NHWC, 0-NCHW - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}, {1., 76., 151., 26., 101., 176., 51., 126., 201., 2., 77., 152., 27., 102., 177., 52., 127., 202., 3., 78., 153., 28., 103., 178., 53., 128., 203., - 4., 79., 154., 29., 104., 179., 54., 129., 204., 5., 80., 155., 30., 105., 180., 55., 130., 205., 6., 81., 156., 31., 106., 181., 56., 131., 206., - 7., 82., 157., 32., 107., 182., 57., 132., 207., 8., 83., 158., 33., 108., 183., 58., 133., 208., 9., 84., 159., 34., 109., 184., 59., 134., 209., - 10., 85., 160., 35., 110., 185., 60., 135., 210., 11., 86., 161., 36., 111., 186., 61., 136., 211., 12., 87., 162., 37., 112., 187., 62., 137., 212., - 13., 88., 163., 38., 113., 188., 63., 138., 213., 14., 89., 164., 39., 114., 
189., 64., 139., 214., 15., 90., 165., 40., 115., 190., 65., 140., 215., - 16., 91., 166., 41., 116., 191., 66., 141., 216., 17., 92., 167., 42., 117., 192., 67., 142., 217., 18., 93., 168., 43., 118., 193., 68., 143., 218., - 19., 94., 169., 44., 119., 194., 69., 144., 219., 20., 95., 170., 45., 120., 195., 70., 145., 220., 21., 96., 171., 46., 121., 196., 71., 146., 221., - 22., 97., 172., 47., 122., 197., 72., 147., 222., 23., 98., 173., 48., 123., 198., 73., 148., 223., 24., 99., 174., 49., 124., 199., 74., 149., 224., - 25., 100., 175.,50., 125., 200.,75., 150., 225.}); + NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, + 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, + 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, + 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, + 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, + 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, 0.273894, 0.431796, 0.133231}); + NDArray weightsD('c', {kH, kW, iC, mC}, {0.5340641736984253, 0.8257383108139038, 0.3279532492160797, 0.27217748761177063, 0.05432872101664543, 0.31322699785232544, 0.6599581837654114, 0.35526034235954285, 0.5765137672424316}); + NDArray weightsP('c', {1, 1, iC*mC, oC}, {0.4442146420478821, 0.3362849950790405, 0.5215804576873779, 0.5305071473121643, 0.7323054075241089, 0.5168435573577881, 0.8601323962211609, 0.2587810158729553, 0.9473239779472351, 0.39540114998817444, 0.04835261031985283, 0.8724213242530823, 0.8607604503631592, 0.8382210731506348, 0.8573186993598938, 0.6496091485023499, 0.8864102959632874, 0.14267340302467346}); + NDArray biases('c', {1,oC}, {0.8807470202445984, 0.6262521147727966}); - auto exp = NDArrayFactory::create('c', {bS, oC, oH, oW}, {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, 23967.0, 48807.0, 74526.0, 101130.0, 
101814.0, 77721.0, 52725.0, 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0}); + NDArray expOutput('c', {bS, oC, oH, oW}, {1.643804, 2.135067, 2.494167, 2.628944, 2.700440, 2.257452, 2.562539, 2.293667, 2.493985, 2.014933, 2.301736, 2.939066, 1.492952, + 2.026476, 1.771098, 2.013162, 1.315507, 1.289951, 2.831223, 2.196924, 2.028261, 2.024326, 2.983223, 1.809527, 1.434322, 2.513157, 1.826834, 1.608869, 1.297912, 1.212318, + 2.295934, 1.844615, 2.591148, 1.597267, 2.317755, 1.755642, 1.324064, 1.542060, 1.892052, 1.939339, 1.922781, 1.720199, 1.833396, 1.728024, 1.757968, 1.410675, 1.661960, + 2.096277, 1.178815, 1.637460, 1.254187, 1.491076, 0.968625, 0.986342, 2.116042, 1.536920, 1.504321, 1.490398, 2.136795, 1.351860, 1.148578, 1.817408, 1.327139, 1.288620, + 0.962232, 0.980667, 1.623775, 1.417320, 1.845710, 1.237095, 1.762792, 1.352515}); - input.linspace(1); - - nd4j::ops::deconv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + nd4j::ops::sconv2d op; + auto results = op.execute({&input, 
&weightsD, &weightsP, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); - auto output = results->at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); delete results; } @@ -705,32 +757,6 @@ TEST_F(ConvolutionTests1, TestDeconv_bp_2) { delete result;*/ } - -TEST_F(ConvolutionTests1, TestDeconv_ff_2) { - - NDArray exp('c', {3, 2, 4, 4}, {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., 1542., 1554., 1566., 1578., 1590., 1602.}); - - auto input = NDArrayFactory::create('c', {3, 3, 4, 4}); - auto weights = NDArrayFactory::create('c',{1, 1, 2, 3}, {1,3,5,2,4,6}); - auto bias = NDArrayFactory::create('c', {2}); - - input.linspace(1); - bias.linspace(1); - - nd4j::ops::deconv2d op; - - auto result = op.execute({&input, &weights, &bias}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto output = result->at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - delete result; -} - TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { auto input = NDArrayFactory::create('c', {2, 2, 6}); auto weights = NDArrayFactory::create('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12}); @@ -1338,6 +1364,40 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_6) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, depthwise_conv2d_7) { + + int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, + 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, + 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}); + NDArray weights('c', {kH, kW, iC, mC}, {0.1308445781469345, 0.6442840099334717, 0.5698848366737366, 0.19896849989891052}); + NDArray biases('c', {1,iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}); + + NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, + 0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472, + 0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504, + 0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 
0.4573034071191504, 0.5033536625992294, 0.5827033826425385, + 0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}); + + + nd4j::ops::depthwise_conv2d op; + auto results = op.execute({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) { @@ -2227,6 +2287,168 @@ TEST_F(ConvolutionTests1, deconv2d_test3) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test4) { + + NDArray input('c', {2, 3, 4, 4}, nd4j::DataType::FLOAT32); + NDArray weights('c', {3, 3, 5, 5}, nd4j::DataType::FLOAT32); + NDArray exp('c', {2,3,8,8}, {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0, + 100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0, + 84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0, + 54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0, + 90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0, + 8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0, + 144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0, + 118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0, + 115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0, + 268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0, + 52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0, + 
78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0, + 89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0}, nd4j::DataType::FLOAT32); + + input.linspace(1); + weights.linspace(1); + weights.permutei({2,3,1,0}); + + nd4j::ops::deconv2d op; + auto result = op.execute({&input, &weights}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + + auto z = result->at(0); + // z->printShapeInfo(); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test5) { + Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, 16384, 1, 99}; + double _expB[] = {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,9974
7.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0,}; + NDArray exp(_expB, _expS); + + auto input = NDArrayFactory::create('c', {2, 3, 4, 4}); + auto weights = NDArrayFactory::create('c', {3, 3, 5, 5}); + auto z = NDArrayFactory::create('c', {2, 3, 8, 8}); + + input.linspace(1); + weights.linspace(1); + weights.permutei({2,3,1,0}); + + nd4j::ops::deconv2d op; + auto result = op.execute({&input, &weights}, {&z}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0},{}); + + ASSERT_EQ(ND4J_STATUS_OK, result); + + ASSERT_TRUE(exp.isSameShape(&z)); + ASSERT_TRUE(exp.equalsTo(&z)); +} + +TYPED_TEST(TypedConvolutionTests1, deconv2d_test6) { + + int bS=2, iH=4,iW=4, iC=3,oC=3, kH=5,kW=5, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=8,oW=8; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}, {1., 76., 151., 26., 101., 176., 51., 126., 201., 2., 77., 152., 27., 102., 177., 52., 127., 202., 3., 78., 153., 28., 103., 178., 53., 128., 203., + 4., 79., 154., 29., 104., 179., 54., 129., 204., 5., 80., 155., 30., 105., 180., 55., 130., 205., 6., 81., 156., 31., 106., 181., 56., 131., 206., + 7., 82., 157., 32., 107., 182., 57., 132., 207., 8., 83., 158., 33., 108., 183., 58., 133., 208., 9., 84., 159., 34., 109., 184., 59., 134., 209., + 10., 85., 160., 35., 110., 185., 60., 135., 210., 11., 86., 161., 36., 111., 186., 61., 136., 211., 12., 87., 162., 37., 112., 187., 62., 137., 212., + 13., 88., 163., 38., 113., 188., 63., 138., 213., 14., 89., 164., 39., 114., 189., 64., 139., 214., 15., 90., 165., 40., 115., 190., 65., 140., 215., + 16., 91., 166., 41., 116., 191., 66., 141., 216., 17., 92., 167., 42., 117., 192., 67., 142., 217., 18., 93., 168., 43., 118., 193., 68., 143., 218., + 19., 94., 169., 44., 119., 194., 69., 144., 219., 20., 95., 170., 45., 120., 195., 70., 145., 220., 21., 96., 171., 46., 121., 196., 71., 146., 221., + 22., 97., 172., 47., 122., 197., 72., 147., 222., 23., 98., 173., 48., 123., 198., 73., 148., 223., 24., 99., 174., 49., 124., 199., 74., 149., 224., + 25., 100., 175.,50., 125., 200.,75., 150., 225.}); + + auto exp = NDArrayFactory::create('c', {bS, oC, oH, oW}, {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, 30693.0, 15660.0, 22389.0, 45696.0, 
69930.0, 95100.0, 95910.0, 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0}); + + input.linspace(1); + + nd4j::ops::deconv2d op; + auto results = op.execute({&input, 
&weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + + ASSERT_EQ(Status::OK(), results->status()); + + auto output = results->at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + delete results; +} + +TEST_F(ConvolutionTests1, deconv2d_test7) { + + NDArray exp('c', {3, 2, 4, 4}, {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., 1542., 1554., 1566., 1578., 1590., 1602.}); + + auto input = NDArrayFactory::create('c', {3, 3, 4, 4}); + auto weights = NDArrayFactory::create('c',{1, 1, 2, 3}, {1,3,5,2,4,6}); + auto bias = NDArrayFactory::create('c', {2}); + + input.linspace(1); + bias.linspace(1); + + nd4j::ops::deconv2d op; + + auto result = op.execute({&input, &weights, &bias}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto output = result->at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test8) { + + int bS=1, iH=7,iW=7, iC=3,oC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=7,oW=7; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, + 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, + 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, + 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, + 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, + 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, 0.273894, 0.431796, 0.133231, 0.192975, + 0.246897, 0.386418, 0.511541, 0.199036, 0.141631, 0.697699, 0.253631, 0.782218, 0.930099, 0.335512, 0.558808, 0.664358, 0.018851, 0.637559, 0.290430, 0.434902, + 0.842513, 0.466098, 0.381395, 0.523185, 0.990183, 0.925768, 0.643459, 0.016828, 0.918756, 0.228979, 0.006314, 0.665975, 0.190361, 0.595521, 0.698881, 0.221469, + 0.912434, 0.870822, 0.727369, 0.523972, 0.662884, 0.218841}); + + NDArray weights('c', {kH, kW, oC, iC}, {0.4195024073123932, 0.22738978266716003, 0.10093523561954498, 0.25008103251457214, 0.3183899223804474, 0.5976081490516663}); + NDArray bias('c', {1, oC}, {0.3596062958240509, 0.6866418123245239}); + 
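+ // Note: as elsewhere in this file, the iArgs vector follows the layout {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}. + // With VALID padding (paddingMode = 0) the deconv2d (transposed convolution) output size is + // oH = (iH - 1)*sH - 2*pH + dH*(kH - 1) + 1 (cf. deconv2d_test6: iH=4, kH=5 -> oH=8; here kH=kW=1, sH=sW=1 -> oH=iH=7, oW=iW=7), + // which is why the expected output below is declared with shape {bS, oC, oH, oW}. +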
+ NDArray exp('c', {bS, oC, oH, oW}, {0.848190, 0.560603, 0.880509, 0.464103, 0.823376, 0.660138, 0.666382, 0.882257, 0.704650, 0.451427, 0.649734, 0.911822, 0.611581, + 0.847623, 0.568191, 0.439341, 0.710854, 0.473843, 0.927273, 0.605861, 0.724540, 0.530591, 0.804268, 0.478136, 0.602198, 0.639553, 0.669082, 0.855013, 0.678572, + 0.617800, 0.667545, 0.765899, 0.835564, 0.631733, 0.921562, 0.790830, 0.588187, 0.597934, 0.725855, 0.822259, 0.455384, 0.998167, 0.683336, 0.591897, 0.705213, + 0.748148, 0.648922, 0.484723, 0.873482, 1.368675, 0.881096, 1.169214, 0.781504, 1.433406, 1.171439, 1.348675, 1.227033, 1.256600, 0.824772, 1.051633, 1.308692, + 1.148711, 1.334007, 1.014448, 0.813336, 1.408801, 0.916766, 1.583323, 1.362920, 1.226212, 1.149715, 1.330235, 0.770671, 1.285158, 1.105632, 1.272558, 1.590159, + 1.235054, 1.201363, 1.222816, 1.623673, 1.590317, 1.322463, 1.206481, 1.466262, 0.974741, 0.922343, 1.367100, 1.087943, 1.084952, 1.586691, 1.133576, 1.405098, + 1.471922, 1.484062, 1.212039, 1.144419, 1.266123}); + + nd4j::ops::deconv2d op; + auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + + ASSERT_EQ(Status::OK(), results->status()); + + auto output = results->at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + delete results; +} + ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 836ad123b..ec67d6dbb 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -605,7 +605,6 @@ TEST_F(ConvolutionTests2, deconv3d_test5) { ASSERT_EQ(Status::OK(), results->status()); auto output = results->at(0); - // output->printBuffer(); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu index 9180393ca..c8b6fa1d9 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu @@ -1342,6 +1342,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_1) { nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, (int*)devicePtrs[0], dimensions.size(), (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); @@ -1400,6 +1401,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_2) { nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, (int*)devicePtrs[0], dimensions.size(), (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 8dd2e7a40..7036ef77f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -2285,66 +2285,6 @@ TEST_F(DeclarableOpsTests1, IsMax4) { ASSERT_EQ(e, z); } -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, CompactLaunchTests1) { - - NDArray 
input('c', {2, 3, 4, 4}, nd4j::DataType::FLOAT32); - NDArray weights('c', {3, 3, 5, 5}, nd4j::DataType::FLOAT32); - NDArray exp('c', {2,3,8,8}, {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0, - 100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0, - 84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0, - 54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0, - 90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0, - 8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0, - 144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0, - 118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0, - 115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0, - 268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0, - 52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0, - 78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0, - 89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0}, nd4j::DataType::FLOAT32); - - input.linspace(1); - weights.linspace(1); - weights.permutei({2,3,1,0}); - - nd4j::ops::deconv2d op; - auto 
result = op.execute({&input, &weights}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); - - auto z = result->at(0); - // z->printShapeInfo(); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, CompactLaunchTests2) { - Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, 16384, 1, 99}; - double _expB[] = {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69
345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0,}; - NDArray exp(_expB, _expS); - - auto input = NDArrayFactory::create('c', {2, 3, 4, 4}); - auto weights = NDArrayFactory::create('c', {3, 3, 5, 5}); - auto z = NDArrayFactory::create('c', {2, 3, 8, 8}); - - input.linspace(1); - weights.linspace(1); - weights.permutei({2,3,1,0}); - - nd4j::ops::deconv2d op; - auto result = op.execute({&input, &weights}, {&z}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0},{}); - - ASSERT_EQ(ND4J_STATUS_OK, result); - - ASSERT_TRUE(exp.isSameShape(&z)); - ASSERT_TRUE(exp.equalsTo(&z)); -} - - //////////////////////////////////////////////////////////////////// // TEST_F(DeclarableOpsTests1, sru_old_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index f0ae83168..13c8db009 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -1470,6 +1470,169 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) { delete results; } +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) { + + NDArray input = NDArrayFactory::create('c', {2,3,4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, + 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., + 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, + 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, + 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, + 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., + 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, + 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., + 20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, + 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., + 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., + 14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6, + 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 
17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., + 20.2,21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); + //input = 1.f; + input.linspace(1); + + nd4j::ops::resize_bilinear op; + auto results = op.execute({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + + //result->printIndexedBuffer("Resized to 10x10"); + //expected.printIndexedBuffer("Expect for 10x10"); +// result->printShapeInfo("Output shape"); +// expected.printShapeInfo("Expect shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) { + + NDArray input = NDArrayFactory::create('c', {2, 5,5,3}, {0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f, + 0.9234f, 0.0856f, 0.7938f, + 0.6591f, 0.5555f, 0.1596f, + 0.3087f, 0.1548f, 0.4695f, + 0.9939f, 0.6113f, 0.6765f, + 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, + 0.2354f, 0.9752f, 0.8361f, + 0.2585f, 0.4189f, 0.7028f, + 0.7679f, 0.5373f, 0.7234f, + 0.2690f, 0.0062f, 0.0327f, + 0.0644f, 0.8428f, 0.7494f, + 0.0755f, 0.6245f, 0.3491f, + 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f, 0.3019f, + 0.3574f, 0.1704f, 0.8395f, + 0.5468f, 0.0744f, 0.9011f, + 0.6574f, 0.4124f, 0.2445f, + 0.4248f, 0.5219f, 0.6952f, + 0.4900f, 0.2158f, 0.9549f, + 0.1386f, 0.1544f, 0.5365f, + 0.0134f, 0.4163f, 0.1456f, + 0.4109f, 0.2484f, 0.3330f, + 0.2974f, 0.6636f, 0.3808f, + 0.8664f, 0.1896f, 0.7530f, + 0.7215f, 0.6612f, 0.7270f, + 0.5704f, 0.2666f, 0.7453f, + 0.0444f, 0.3024f, 0.4850f, + 0.7982f, 0.0965f, 0.7843f, + 0.5075f, 0.0844f, 0.8370f, + 0.6103f, 0.4604f, 0.6087f, + 0.8594f, 0.4599f, 0.6714f, + 0.2744f, 0.1981f, 0.4143f, + 0.7821f, 0.3505f, 0.5040f, + 0.1180f, 0.8307f, 0.1817f, + 0.8442f, 0.5074f, 0.4471f, + 0.5105f, 0.6666f, 0.2576f, + 0.2341f, 0.6801f, 0.2652f, + 0.5394f, 0.4690f, 0.6146f, + 0.1210f, 0.2576f, 0.0769f, + 0.4643f, 0.1628f, 0.2026f, + 0.3774f, 0.0506f, 0.3462f, + 0.5720f, 0.0838f, 0.4228f, + 0.0588f, 0.5362f, 0.4756f, + 0.2530f, 0.1778f, 0.0751f, + 0.8977f, 0.3648f, 0.3065f, + 0.4739f, 0.7014f, 0.4473f, + 0.5171f, 0.1744f, 0.3487f}); + + NDArray expected = NDArrayFactory::create('c', {10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, + 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., + 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, + 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, + 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, + 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., + 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, + 17.8,18.8, 19.8, 
20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,
+            21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.,
+            13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4,
+            16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22.,
+            20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24.,
+            21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,
+            15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,
+            19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24.,
+            21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16.,
+            14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,
+            17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,
+            21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.,
+            13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,
+            16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22.,
+            20.2,21.2, 22.2, 23.2,
+            21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.});
+
+    nd4j::ops::resize_bilinear op;
+    auto results = op.execute({&input}, {}, {9, 9});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    NDArray* result = results->at(0);
+
+    // TODO: the expected array above was copied from Test01 and does not match this
+    // input (a {2,5,5,3} batch resized to 9x9 cannot produce a {10,10,4} result), so
+    // the shape/value asserts stay disabled until correct reference values are generated.
+//    result->printIndexedBuffer("Resized to 9x9");
+//    expected.printIndexedBuffer("Expect for 9x9");
+//    result->printShapeInfo("Output shape");
+//    expected.printShapeInfo("Expect shape");
+//    ASSERT_TRUE(expected.isSameShape(result));
+//    ASSERT_TRUE(expected.equalsTo(result));
+    delete results;
+}
+
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) {
 
@@ -1844,6 +2007,53 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
 
     NDArray* result = results->at(0);
 
+//    result->printIndexedBuffer("Resized to 4x5");
+//    expected.printIndexedBuffer("Expect for 4x5");
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) {
+
+    NDArray input = NDArrayFactory::create('c', {2, 3, 4});
+    NDArray expected = NDArrayFactory::create('c', {4, 5, 4}, {  1,  2,  3,  4,
+     1,  2,  3,  4,
+     5,  6,  7,  8,
+     5,  6,  7,  8,
+     9, 10, 11, 12,
+
+     1,  2,  3,  4,
+     1,  2,  3,  4,
+     5,  6,  7,  8,
+     5,  6,  7,  8,
+     9, 10, 11, 12,
+
+    13, 14, 15, 16,
+    13, 14, 15, 16,
+    17, 18, 19, 20,
+    17, 18, 19, 20,
+    21, 22, 23, 24,
+
+    13, 14, 15, 16,
+    13, 14, 15, 16,
+    17, 18, 19, 20,
+    17, 18, 19, 20,
+    21, 22, 23, 24
+    });
+    input.linspace(1);
+
+    nd4j::ops::resize_nearest_neighbor op;
+    auto results = op.execute({&input}, {}, {4, 5});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    NDArray* result = results->at(0);
+
     //result->printIndexedBuffer("Resized to 4x5");
     //expected.printIndexedBuffer("Expect for 4x5");
     ASSERT_TRUE(expected.isSameShape(result));
@@ -2040,7 +2250,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
     NDArray cropSize = NDArrayFactory::create({1, 1});
 
     //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
-    NDArray expected = NDArrayFactory::create('c', {1,1,1,1}, {2.5f});
+    NDArray expected = NDArrayFactory::create('c', {1,1,1,1}, {2.5f});
 
     nd4j::ops::crop_and_resize op;
     auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {});
@@ -2049,6 +2259,7 @@
 
     auto result = results->at(0);
 
     // 
result->printIndexedBuffer("Cropped and Resized"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2082,7 +2293,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { - NDArray images ('c', {1,2,2,1}, {1,2,3,4}); + NDArray images ('c', {1,2,2,1}, {1,2,3,4}, nd4j::DataType::FLOAT32); NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32); NDArray boxI('c', {1}, {0}, nd4j::DataType::INT64); NDArray cropSize = NDArrayFactory::create({3, 3}); @@ -2106,7 +2317,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { - NDArray images('c', {1,2,2,1}, {1, 2, 3, 4}); + NDArray images('c', {1,2,2,1}, {1, 2, 3, 4}, nd4j::DataType::FLOAT32); NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32); NDArray boxI('c', {1}, {0}, nd4j::DataType::INT32); NDArray cropSize = NDArrayFactory::create({3, 3}); @@ -2130,7 +2341,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) { - NDArray images('c', {1, 100, 100, 3}); + NDArray images('c', {1, 100, 100, 3}, nd4j::DataType::FLOAT32); NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32); NDArray boxI('c', {2}, {1,1}, nd4j::DataType::INT32); NDArray cropSize = NDArrayFactory::create({10, 10}); @@ -2155,23 +2366,23 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) { TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { NDArray images = NDArrayFactory::create('c', {2,4,5,3}); NDArray boxes = NDArrayFactory::create('c', {2, 2, 4}, { - 0. , 0. , 1. , 1. , 0.1, 0.2, 0.9, 0.8, - 0.3, 0.3, 0.7, 0.7, 0.4, 0.4, 0.6, 0.6 + 0.f , 0.f , 1.f , 1.f , 0.1f, 0.2f, 0.9f, 0.8f, + 0.3f, 0.3f, 0.7f, 0.7f, 0.4f, 0.4f, 0.6f, 0.6f }); - NDArray colors = NDArrayFactory::create('c', {2, 3}, {201., 202., 203., 127., 128., 129.}); + NDArray colors = NDArrayFactory::create('c', {2, 3}, {201.f, 202.f, 203.f, 127.f, 128.f, 129.f}); //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); NDArray expected = NDArrayFactory::create('c', {2,4,5,3}, { - 127., 128., 129., 127., 128., 129., 127., 128., 129., 127., 128., 129., 201., 202., 203., - 127., 128., 129., 19., 20., 21., 22., 23., 24., 127., 128., 129., 201., 202., 203., - 127., 128., 129., 127., 128., 129., 127., 128., 129., 127., 128., 129., 201., 202., 203., - 201., 202., 203., 201. ,202. ,203., 201., 202., 203., 201., 202., 203., 201., 202., 203., + 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 127.f, 128.f, 129.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, - 61., 62., 63., 201., 202., 203., 201., 202., 203., 70., 71., 72., 73., 74., 75., - 76., 77., 78., 127., 128., 129., 127., 128., 129., 85., 86., 87., 88., 89., 90., - 91., 92., 93., 201., 202., 203., 201., 202., 203., 100., 101., 102., 103., 104., 105., - 106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120. 
+ 61.f, 62.f, 63.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 70.f, 71.f, 72.f, 73.f, 74.f, 75.f, + 76.f, 77.f, 78.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, 101.f, 102.f, 103.f, 104.f, 105.f, + 106.f, 107.f, 108.f, 109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f }); images.linspace(1.); nd4j::ops::draw_bounding_boxes op; @@ -2180,6 +2391,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); + result->syncToHost(); // result->printBuffer("Bounded boxes"); // expected.printBuffer("Bounded expec"); ASSERT_TRUE(expected.isSameShapeStrict(result)); @@ -2191,20 +2403,20 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { NDArray images = NDArrayFactory::create('c', {1,9,9,1}); - NDArray boxes = NDArrayFactory::create('c', {1, 1, 4}, {0.2, 0.2, 0.7, 0.7}); - NDArray colors = NDArrayFactory::create('c', {1, 1}, {0.95}); + NDArray boxes = NDArrayFactory::create('c', {1, 1, 4}, {0.2f, 0.2f, 0.7f, 0.7f}); + NDArray colors = NDArrayFactory::create('c', {1, 1}, {0.95f}); //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); NDArray expected = NDArrayFactory::create('c', {1,9,9,1}, { - 1.1 , 2.1, 3.1 , 4.1 , 5.1 , 6.1 , 7.1 , 8.1 , 9.1 , - 10.1 , 0.95, 0.95, 0.95, 0.95, 0.95, 16.1 , 17.1 , 18.1 , - 19.1 , 0.95, 21.1, 22.1, 23.1, 0.95, 25.1 , 26.1 , 27.1 , - 28.1 , 0.95, 30.1, 31.1, 32.1, 0.95, 34.1 , 35.1 , 36.1 , - 37.1 , 0.95, 39.1, 40.1, 41.1, 0.95, 43.1 , 44.1 , 45.1 , - 46.1 , 0.95, 0.95, 0.95, 0.95, 0.95, 52.1 , 53.1 , 54.1 , - 55.1 , 56.1, 57.1 , 58.1 , 59.1 , 60.1 , 61.1 , 62.1 , 63.1 , - 64.1 , 65.1, 66.1 , 67.1 , 68.1 , 69.1 , 70.1 , 71.1 , 72.1 , - 73.1 , 74.1, 75.1 , 76.1 , 77.1 , 78.1 , 79.1 , 80.1 , 81.1 }); + 1.1f , 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f , 8.1f , 9.1f , + 10.1f , 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 16.1f , 17.1f , 18.1f , + 19.1f , 0.95f, 21.1f, 22.1f, 23.1f, 0.95f, 25.1f , 26.1f , 27.1f , + 28.1f , 0.95f, 30.1f, 31.1f, 32.1f, 0.95f, 34.1f , 35.1f , 36.1f , + 37.1f , 0.95f, 39.1f, 40.1f, 41.1f, 0.95f, 43.1f , 44.1f , 45.1f , + 46.1f , 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 52.1f , 53.1f , 54.1f , + 55.1f , 56.1f, 57.1f, 58.1f, 59.1f , 60.1f, 61.1f , 62.1f , 63.1f , + 64.1f , 65.1f, 66.1f, 67.1f, 68.1f , 69.1f, 70.1f , 71.1f , 72.1f , + 73.1f , 74.1f, 75.1f, 76.1f, 77.1f , 78.1f, 79.1f , 80.1f , 81.1f }); images.linspace(1.1); nd4j::ops::draw_bounding_boxes op; auto results = op.execute({&images, &boxes, &colors}, {}, {}); @@ -2221,6 +2433,61 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { delete results; } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) { + NDArray images = NDArrayFactory::create('c', {2,5,5,1}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, + 0.6591f, 0.5555f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, + 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, + 0.2585f, 0.4189f, 0.7028f, 0.7679f, 0.5373f, 0.7234f, + 0.2690f, 0.0062f, 0.0327f, 0.0644f, 0.8428f, 0.7494f, + 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f}); + + NDArray boxes = NDArrayFactory::create('c', {2, 2, 4}, 
{0.7717f, 0.9281f, 0.9846f, 0.4838f,
+                    0.6433f, 0.6041f, 0.6501f, 0.7612f,
+                    0.7605f, 0.3948f, 0.9493f, 0.8600f,
+                    0.7876f, 0.8945f, 0.4638f, 0.7157f});
+    NDArray colors = NDArrayFactory::create('c', {1, 2}, {0.9441f, 0.5957f});
+
+    NDArray expected = NDArrayFactory::create('c', {2,5,5,1}, {
+            0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
+            0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
+            0.0856f, 0.7938f, 0.9441f, 0.9441f, 0.1596f,
+            0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f,
+            0.6765f, 0.18f  , 0.675f , 0.2246f, 0.0509f,
+
+            0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f,
+            0.2585f, 0.4189f, 0.7028f, 0.7679f, 0.5373f,
+            0.7234f, 0.269f , 0.0062f, 0.0327f, 0.0644f,
+            0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f,
+            0.5793f, 0.573f , 0.1822f, 0.642f , 0.9143f});
+    nd4j::ops::draw_bounding_boxes op;
+    auto results = op.execute({&images, &boxes, &colors}, {}, {});
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto result = results->at(0);
+//    result->syncToHost();
+//    result->printBuffer("Boxes3 output");
+//    expected.printBuffer("Boxes3 expect");
+    ASSERT_TRUE(expected.isSameShapeStrict(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
 
@@ -2376,6 +2643,55 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
 
     delete results;
 }
 
+/* public void testFakeQuantAgainstTF_1() {
+        INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
+            0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
+            0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
+        INDArray min = Nd4j.createFromArray(new float[]{-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}).reshape(1,5);
+        INDArray max = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}).reshape(1,5);
+
+        INDArray out = Nd4j.createUninitialized(x.shape());
+        val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out);
+
+        INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
+            0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
+            0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
+
+        assertEquals(expected, out);
+    }*/
+TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
+    NDArray x = NDArrayFactory::create('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
+                                                     0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
+                                                     0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});
+//    NDArray exp = NDArrayFactory::create('c', {3, 5},{
+//            0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
+//            0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
+//            0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f
+//    });
+
+    NDArray exp = NDArrayFactory::create('c', {3,5}, {0.77700233, 0.596913, 0.72314, 0.23104, 0.50982356,
+                                                      0.17930824, 0.50528157, 0.86846, 0.34995764, 
0.50982356, + 0.08735529, 0.596913, 0.6574, 0.34995764, 0.15974471}); + NDArray min = NDArrayFactory::create('c', {5}, {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + // x.linspace(-60.); + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); +// result->printBuffer("Quantized per channels 5"); +// exp.printBuffer("Quantized per channest E"); +// auto diff = *result - exp; +// diff.printIndexedBuffer("Difference"); + + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, batchnorm_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index d077f886d..ac56c496f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -477,6 +477,617 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test13) { delete results; } +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) { + + NDArray input = NDArrayFactory::create('c', {1, 7, 7, 1}, { + 1, 2.1, 3.15, 4.2, 5.15, 6.1, 7, + 8, 9.1, 10., 11, 12.9, 13.1, 14, + 15, 16., 17., 18, 19, 20., 21, + 22, 23., 24., 25, 26, 27, 28, + 30, 31, 32, 33, 34., 35, 36, + 37, 38, 39, 40, 41., 42, 43, + 44, 45, 46, 47, 48., 49, 50 + }); + NDArray expected = NDArrayFactory::create('c', {1, 30, 30, 1}, { + 1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 , + 2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 , + 3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 , + 5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 , + 6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 , + 2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 , + 3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 , + 5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 , + 6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 , + 7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 , + 3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 , + 5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 , + 6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 , + 8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 , + 9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 , + 5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 , + 6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 , + 8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 , + 10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 , + 10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 , + 7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 , + 8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 , + 9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 , + 12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 , + 12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 , + 9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 , + 10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 , + 12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 , + 14.528972 ,14.555555 
,14.50145 ,14.515459 ,14.700572 ,14.927055 , + 15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 , + 10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 , + 12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 , + 13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 , + 15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 , + 16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 , + 12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 , + 13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 , + 14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 , + 16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 , + 17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 , + 13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 , + 15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 , + 16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 , + 18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 , + 19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 , + 15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 , + 17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 , + 18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 , + 20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 , + 21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 , + 17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 , + 18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 , + 20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 , + 21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 , + 23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 , + 18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 , + 20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 , + 21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 , + 22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 , + 24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 , + 20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 , + 21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 , + 22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 , + 24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 , + 25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 , + 22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 , + 23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 , + 25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 , + 26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 , + 28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 , + 24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 , + 25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 , + 27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 , + 28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 , + 30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 , + 26. 
,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 , + 27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 , + 28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 , + 30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 , + 31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 , + 27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 , + 28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 , + 30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 , + 31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 , + 33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 , + 29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 , + 31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 , + 32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 , + 33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 , + 35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 , + 31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 , + 33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 , + 34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 , + 36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 , + 37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 , + 33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 , + 34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 , + 36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 , + 37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 , + 38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 , + 34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 , + 35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 , + 37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 , + 38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 , + 40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 , + 36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 , + 37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 , + 38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 , + 40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 , + 41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 , + 38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 , + 39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 , + 41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 , + 42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 , + 43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 , + 40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 , + 41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 , + 42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 , + 44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 , + 45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 , + 41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 , + 43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 , + 44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 , + 46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 , + 47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 , + 43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 , + 44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 , + 45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 , + 47.363342 ,47.56505 ,47.74505 
,47.979626 ,48.302998 ,48.576153 , + 48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 , + 44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 , + 45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 , + 47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 , + 48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 , + 49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 , + 44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 , + 46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 , + 47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 , + 49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 , + 50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 , + 44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 , + 46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 , + 47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 , + 48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 , + 50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 , + 44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 , + 45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 , + 46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 , + 48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 , + 49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057}); + + auto size = NDArrayFactory::create({30, 30}); + nd4j::ops::resize_bicubic op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 30x30"); +// expected.printBuffer("Expect for 30x30"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) { + + NDArray input = NDArrayFactory::create('c', {2, 5, 4, 3}); + NDArray expected = NDArrayFactory::create('c', {2, 10, 8, 3}, { + 1. , 2. ,3. ,2.21875, 3.21875, 4.21875, 4. , 5. , 6. ,5.5, + 6.5, 7.5, 7., 8., 9. ,8.78125, 9.78125, 10.78125, 10., 11. , + 12., 10.28125 , 11.28125 ,12.28125, 5.875, 6.875, 7.875, 7.09375, 8.09375 ,9.09375, + 8.875, 9.875, 10.875, 10.375, 11.375, 12.375 ,11.875 ,12.875 , 13.875, 13.65625, + 14.65625, 15.65625, 14.875 ,15.875 ,16.875 , 15.15625, 16.15625, 17.15625, 13., 14., + 15. ,14.21875, 15.21875, 16.21875, 16., 17., 18. ,17.5 ,18.5 , 19.5, + 19., 20., 21., 20.78125 ,21.78125 ,22.78125, 22., 23. , 24. , 22.28125, + 23.28125 ,24.28125 ,19. , 20., 21., 20.21875, 21.21875, 22.21875 ,22. ,23., + 24. , 23.5, 24.5, 25.5, 25. ,26. ,27., 26.78125 , 27.78125, 28.78125, + 28., 29. ,30. ,28.28125, 29.28125, 30.28125, 25., 26., 27. ,26.21875, + 27.21875, 28.21875, 28., 29., 30., 29.5 ,30.5 ,31.5 , 31., 32., + 33., 32.78125, 33.78125 ,34.78125 ,34., 35., 36., 34.28125, 35.28125, 36.28125, + 31. ,32., 33. , 32.21875, 33.21875, 34.21875, 34. ,35. ,36., 35.5, + 36.5 , 37.5 , 37., 38. ,39. ,38.78125, 39.78125, 40.78125, 40., 41., + 42. ,40.28125 ,41.28125, 42.28125, 37. , 38., 39., 38.21875 ,39.21875 ,40.21875, + 40. , 41. , 42. , 41.5, 42.5, 43.5 ,43., 44., 45., 44.78125, + 45.78125, 46.78125 ,46. ,47. , 48. , 46.28125 , 47.28125, 48.28125, 44.125 ,45.125, + 46.125, 45.34375, 46.34375, 47.34375, 47.125, 48.125 ,49.125 ,48.625, 49.625 , 50.625, + 50.125 , 51.125, 52.125 ,51.90625 ,52.90625, 53.90625, 53.125, 54.125, 55.125, 53.40625, + 54.40625 ,55.40625, 49. ,50. , 51. ,50.21875, 51.21875 ,52.21875 ,52. 
,53., + 54. ,53.5 , 54.5, 55.5 ,55. ,56. ,57. ,56.78125 ,57.78125, 58.78125, + 58. ,59. ,60. ,58.28125 ,59.28125 ,60.28125, 50.125, 51.125 ,52.125 ,51.34375, + 52.34375 ,53.34375 ,53.125, 54.125, 55.125 ,54.625 ,55.625 ,56.625 ,56.125 ,57.125, + 58.125, 57.90625 ,58.90625 ,59.90625 ,59.125 ,60.125 ,61.125, 59.40625, 60.40625 ,61.40625, + 61. ,62. ,63. ,62.21875, 63.21875, 64.21875 ,64. ,65. ,66. ,65.5 , + 66.5, 67.5, 67. ,68. ,69. ,68.78125 ,69.78125 ,70.78125 ,70., 71. , + 72. ,70.28125 ,71.28125 ,72.28125 ,65.875 ,66.875, 67.875 ,67.09375 ,68.09375 ,69.09375, + 68.875 ,69.875 ,70.875, 70.375 ,71.375 ,72.375 ,71.875 ,72.875 ,73.875 ,73.65625, + 74.65625 ,75.65625 ,74.875 ,75.875, 76.875 ,75.15625 ,76.15625, + 77.15625 ,73. ,74. ,75., 74.21875 ,75.21875 ,76.21875, + 76. ,77. ,78. ,77.5 ,78.5 ,79.5 ,79., + 80. ,81. ,80.78125 ,81.78125, 82.78125 ,82. ,83., + 84. ,82.28125 ,83.28125 ,84.28125, 79. ,80. ,81., + 80.21875 ,81.21875 ,82.21875 ,82., 83. ,84. ,83.5, + 84.5 ,85.5 ,85. ,86., 87. ,86.78125 ,87.78125, + 88.78125 ,88. ,89. ,90., 88.28125 ,89.28125 ,90.28125, + 85. ,86. ,87. ,86.21875, 87.21875 ,88.21875 ,88., + 89. ,90. ,89.5 ,90.5, 91.5 ,91. ,92., + 93. ,92.78125 ,93.78125 ,94.78125, 94. ,95. ,96., + 94.28125 ,95.28125 ,96.28125 ,91., 92. ,93. ,92.21875, + 93.21875 ,94.21875 ,94. ,95., 96. ,95.5 ,96.5, + 97.5 ,97. ,98. ,99., 98.78125 ,99.78125 ,100.78125, + 100. ,101. ,102. ,100.28125, 101.28125 ,102.28125, 97., + 98. ,99. ,98.21875 ,99.21875, 100.21875 ,100., 101., + 102. ,101.5 ,102.5 ,103.5, 103. ,104., 105., + 104.78125 ,105.78125 ,106.78125 ,106., 107. ,108., 106.28125, + 107.28125 ,108.28125 ,104.125 ,105.125, 106.125 ,105.34375, 106.34375, + 107.34375 ,107.125 ,108.125 ,109.125, 108.625 ,109.625, 110.625, + 110.125 ,111.125 ,112.125 ,111.90625, 112.90625 ,113.90625, 113.125, + 114.125 ,115.125 ,113.40625 ,114.40625, 115.40625 ,109., 110., + 111. ,110.21875 ,111.21875 ,112.21875, 112., 113., 114., + 113.5 ,114.5 ,115.5 ,115., 116., 117., 116.78125, + 117.78125 ,118.78125 ,118. ,119., 120., 118.28125, 119.28125, + 120.28125 ,110.125 ,111.125 ,112.125, 111.34375, 112.34375, 113.34375, + 113.125 ,114.125 ,115.125 ,114.625, 115.625, 116.625, 116.125, + 117.125 ,118.125 ,117.90625, 118.90625, 119.90625, 119.125, 120.125, + 121.125 ,119.40625 ,120.40625, 121.40625}); //input = 1.f; + input.linspace(1); + auto size = NDArrayFactory::create({10, 8}); + nd4j::ops::resize_bicubic op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 10x8"); +// expected.printBuffer("Expect for 10x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 4}, { + 1. ,2. ,3. ,4., + 2.625 ,3.625 ,4.625 ,5.625, + 5. ,6. ,7. ,8., + 7.375 ,8.375 ,9.375, 10.375, + 9. ,10. ,11. ,12., + 9.375 ,10.375 ,11.375 ,12.375, + + 5.875 ,6.875 ,7.875 , 8.875 , + 7.5 ,8.5 ,9.5 , 10.5 , + 9.875 ,10.875 ,11.875, 12.875, + 12.25 ,13.25 ,14.25 , 15.25 , + 13.875 ,14.875 ,15.875, 16.875, + 14.25 ,15.25 ,16.25 , 17.25 , + + 13. ,14. ,15. ,16., + 14.625 ,15.625 ,16.625 ,17.625, + 17. ,18. ,19. ,20., + 19.375 ,20.375 ,21.375 ,22.375, + 21. ,22. ,23. 
,24., + 21.375 ,22.375 ,23.375 ,24.375, + + 20.125 ,21.125 ,22.125 ,23.125, + 21.75 ,22.75 ,23.75 ,24.75, + 24.125 ,25.125 ,26.125 ,27.125, + 26.5 ,27.5 ,28.5 ,29.5, + 28.125 ,29.125 ,30.125 ,31.125, + 28.5 ,29.5 ,30.5 ,31.5, + + 25. , 26. , 27. , 28., + 26.625 ,27.625 ,28.625 ,29.625, + 29. ,30. ,31. ,32., + 31.375 ,32.375 ,33.375 ,34.375, + 33. ,34. ,35. ,36., + 33.375 ,34.375 ,35.375 ,36.375, + + 26.125, 27.125, 28.125, 29.125, + 27.75 ,28.75 ,29.75 ,30.75, + 30.125 ,31.125 ,32.125 ,33.125, + 32.5 ,33.5 ,34.5 ,35.5, + 34.125 ,35.125 ,36.125 ,37.125, + 34.5 ,35.5 ,36.5 ,37.5 + }); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_bicubic op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 6x6"); +// expected.printBuffer("Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 4, 3}); + NDArray expected = NDArrayFactory::create('c', {1, 6, 8, 3}, { + 1. , 2. , 3. , + 2.21875 ,3.21875 ,4.21875, + 4. ,5. ,6. , + 5.5 ,6.5 ,7.5 , + 7. ,8. ,9. , + 8.78125 ,9.78125, 10.78125, + 10. ,11., 12. , + 10.28125 ,11.28125, 12.28125, + + 5.875 , 6.875 , 7.875 , + 7.09375 , 8.09375 , 9.09375, + 8.875 , 9.875 ,10.875 , + 10.375 ,11.375 ,12.375 , + 11.875 ,12.875 ,13.875 , + 13.65625 ,14.65625 ,15.65625, + 14.875 ,15.875 ,16.875 , + 15.15625 ,16.15625 ,17.15625, + + 13., 14., 15., + 14.21875 ,15.21875 ,16.21875, + 16. ,17. ,18. , + 17.5 ,18.5 ,19.5 , + 19. ,20. ,21. , + 20.78125 ,21.78125 ,22.78125, + 22. ,23. ,24. , + 22.28125 ,23.28125 ,24.28125, + + 20.125 , 21.125 , 22.125, + 21.34375 ,22.34375 ,23.34375, + 23.125 ,24.125 ,25.125 , + 24.625 ,25.625 ,26.625 , + 26.125 ,27.125 ,28.125 , + 27.90625 ,28.90625 ,29.90625, + 29.125 ,30.125 ,31.125 , + 29.40625 ,30.40625 ,31.40625, + + 25. ,26. ,27. , + 26.21875 ,27.21875 ,28.21875, + 28. ,29. ,30. , + 29.5 ,30.5 ,31.5 , + 31. ,32. ,33. , + 32.78125 ,33.78125 ,34.78125, + 34. ,35. ,36. , + 34.28125 ,35.28125 ,36.28125, + + 26.125 ,27.125 , 28.125 , + 27.34375 ,28.34375 ,29.34375, + 29.125 ,30.125 ,31.125 , + 30.625 ,31.625 ,32.625 , + 32.125 ,33.125 ,34.125 , + 33.90625 ,34.90625 ,35.90625, + 35.125 ,36.125 ,37.125 , + 35.40625 ,36.40625 ,37.40625 }); + input.linspace(1); + auto size = NDArrayFactory::create({6, 8}); + nd4j::ops::resize_bicubic op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 6x8"); +// expected.printBuffer("Expect for 6x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) { + + NDArray input = NDArrayFactory::create('c', {1, 4, 4, 3}); + NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 3}, { + 1. ,2. , 3. , 2.21875 , 3.21875 , 4.21875 , 4. , 5. , + 6. ,5.5 , 6.5 , 7.5 , 7. , 8. , 9. , 8.78125 , + 9.78125 ,10.78125 ,10. ,11. ,12. ,10.28125 ,11.28125 ,12.28125 , + 5.875 ,6.875 , 7.875 , 7.09375 , 8.09375 , 9.09375 , 8.875 , 9.875 , + 10.875 ,10.375 , 11.375 , 12.375 , 11.875 , 12.875 , 13.875 , 13.65625, + 14.65625 ,15.65625, 14.875 , 15.875 , 16.875 , 15.15625, 16.15625, 17.15625, + 13. ,14. , 15. , 14.21875, 15.21875, 16.21875, 16. , 17. 
, + 18. ,17.5 , 18.5 , 19.5 , 19. , 20. , 21. , 20.78125, + 21.78125 ,22.78125, 22. , 23. , 24. , 22.28125, 23.28125, 24.28125, + 19. ,20. , 21. , 20.21875, 21.21875, 22.21875, 22. , 23. , + 24. ,23.5 , 24.5 , 25.5 , 25. , 26. , 27. , 26.78125, + 27.78125 ,28.78125, 28. , 29. , 30. , 28.28125, 29.28125, 30.28125, + 25. ,26. , 27. , 26.21875, 27.21875, 28.21875, 28. , 29. , + 30. ,29.5 , 30.5 , 31.5 , 31. , 32. , 33. , 32.78125, + 33.78125 ,34.78125, 34. , 35. , 36. , 34.28125, 35.28125, 36.28125, + 32.125 ,33.125 , 34.125 , 33.34375, 34.34375, 35.34375, 35.125 , 36.125 , + 37.125 ,36.625 , 37.625 , 38.625 , 38.125 , 39.125 , 40.125 , 39.90625, + 40.90625 ,41.90625, 41.125 , 42.125 , 43.125 , 41.40625, 42.40625, 43.40625, + 37. ,38. , 39. , 38.21875, 39.21875, 40.21875, 40. , 41. , + 42. ,41.5 , 42.5 , 43.5 , 43. , 44. , 45. , 44.78125, + 45.78125 ,46.78125, 46. , 47. , 48. , 46.28125, 47.28125, 48.28125, + 38.125 ,39.125 , 40.125 , 39.34375, 40.34375, 41.34375, 41.125 , 42.125 , + 43.125 ,42.625 , 43.625 , 44.625 , 44.125 , 45.125 , 46.125 , 45.90625, + 46.90625 ,47.90625, 47.125 , 48.125 , 49.125 , 47.40625, 48.40625, 49.40625, + }); + input.linspace(1); + auto size = NDArrayFactory::create({8, 8}); + nd4j::ops::resize_bicubic op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 8x8"); +// expected.printBuffer("Expect for 8x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) { + + NDArray input = NDArrayFactory::create('c', {7, 7, 1}, { + 1, 2.1, 3.15, 4.2, 5.15, 6.1, 7, + 8, 9.1, 10., 11, 12.9, 13.1, 14, + 15, 16., 17., 18, 19, 20., 21, + 22, 23., 24., 25, 26, 27, 28, + 30, 31, 32, 33, 34., 35, 36, + 37, 38, 39, 40, 41., 42, 43, + 44, 45, 46, 47, 48., 49, 50 + }); + + NDArray expected = NDArrayFactory::create('c', {30, 30, 1}, { + 1. 
,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 , + 2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 , + 3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 , + 5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 , + 6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 , + 2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 , + 3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 , + 5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 , + 6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 , + 7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 , + 3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 , + 5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 , + 6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 , + 8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 , + 9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 , + 5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 , + 6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 , + 8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 , + 10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 , + 10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 , + 7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 , + 8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 , + 9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 , + 12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 , + 12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 , + 9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 , + 10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 , + 12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 , + 14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 , + 15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 , + 10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 , + 12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 , + 13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 , + 15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 , + 16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 , + 12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 , + 13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 , + 14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 , + 16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 , + 17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 , + 13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 , + 15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 , + 16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 , + 18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 , + 19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 , + 15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 , + 17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 , + 18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 , + 20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 , + 21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 , + 17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 , + 18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 , + 20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 , + 21.689209 
,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 , + 23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 , + 18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 , + 20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 , + 21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 , + 22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 , + 24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 , + 20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 , + 21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 , + 22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 , + 24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 , + 25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 , + 22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 , + 23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 , + 25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 , + 26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 , + 28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 , + 24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 , + 25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 , + 27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 , + 28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 , + 30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 , + 26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 , + 27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 , + 28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 , + 30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 , + 31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 , + 27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 , + 28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 , + 30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 , + 31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 , + 33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 , + 29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 , + 31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 , + 32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 , + 33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 , + 35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 , + 31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 , + 33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 , + 34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 , + 36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 , + 37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 , + 33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 , + 34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 , + 36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 , + 37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 , + 38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 , + 34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 , + 35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 , + 37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 , + 38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 , + 40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 , + 36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 , + 37.48237 ,37.6608 
,37.881836 ,38.19359 ,38.42817 ,38.608162 , + 38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 , + 40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 , + 41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 , + 38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 , + 39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 , + 41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 , + 42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 , + 43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 , + 40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 , + 41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 , + 42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 , + 44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 , + 45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 , + 41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 , + 43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 , + 44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 , + 46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 , + 47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 , + 43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 , + 44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 , + 45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 , + 47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 , + 48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 , + 44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 , + 45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 , + 47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 , + 48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 , + 49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 , + 44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 , + 46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 , + 47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 , + 49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 , + 50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 , + 44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 , + 46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 , + 47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 , + 48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 , + 50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 , + 44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 , + 45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 , + 46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 , + 48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 , + 49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057}); + + auto size = NDArrayFactory::create({30, 30}); + nd4j::ops::resize_bicubic op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 30x30"); +// expected.printBuffer("Expect for 30x30"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, summaryStatsData_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp 
b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 97e7d2d91..ff554f837 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
@@ -551,4 +551,34 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_3) {
 auto temp1 = ft.reshape('f', {bS, nIn});
 auto temp2 = temp1 * cLast;
 }
+}
+
+TEST_F(DeclarableOpsTests15, test_empty_increasing_1) {
+    auto x = NDArrayFactory::create<float>('c', {1, 0, 3});
+    auto z = NDArrayFactory::create<bool>(false);
+
+    Context ctx(1);
+    ctx.setInputArray(0, &x);
+    ctx.setOutputArray(0, &z);
+
+    nd4j::ops::is_strictly_increasing op;
+    auto status = op.execute(&ctx);
+    ASSERT_EQ(Status::OK(), status);
+
+    ASSERT_EQ(true, z.e<bool>(0));
+}
+
+TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) {
+    auto x = NDArrayFactory::create<float>('c', {1, 0, 3});
+    auto z = NDArrayFactory::create<bool>(false);
+
+    Context ctx(1);
+    ctx.setInputArray(0, &x);
+    ctx.setOutputArray(0, &z);
+
+    nd4j::ops::is_non_decreasing op;
+    auto status = op.execute(&ctx);
+    ASSERT_EQ(Status::OK(), status);
+
+    ASSERT_EQ(true, z.e<bool>(0));
}
\ No newline at end of file
diff --git a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp index 1dc2c8e48..085127e74 100644
--- a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp
+++ b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -1798,6 +1799,7 @@ TEST_F(HelpersTests1, tensordot_test_6) {
 // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
 MmulHelper::tensorDot(&a, &b, &cR, {{1,0,4,5,2,3}, {iC,bS*oH*oW,kW*kH}}, {{2,0,1,3},{iC,kH*kW,mC}}, {{3,0,1,2,4},{iC, bS*oH*oW, mC}});
+ // c.printBuffer();
 ASSERT_TRUE(c.isSameShape(expected));
 ASSERT_TRUE(c.equalsTo(expected));
diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index aa75ea1ab..a6ca56fd4 100644
--- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp
@@ -426,6 +426,26 @@ TEST_F(JavaInteropTests, Test_FastPath_Validation_2) {
 ASSERT_NE(Status::OK(), status);
 }
+TEST_F(JavaInteropTests, Test_FastPath_Validation_3) {
+    auto x = NDArrayFactory::create<float>('c', {3, 5}, { 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
+                                                          0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
+                                                          0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});
+
+    auto min = NDArrayFactory::create<float>({ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
+    auto max = NDArrayFactory::create<float>({ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
+
+    auto z = NDArrayFactory::create<float>('c', {3, 5});
+
+    Context ctx(1);
+    ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
+    ctx.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), min.specialShapeInfo());
+    ctx.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), max.specialShapeInfo());
+    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
+
+    nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
+    ASSERT_ANY_THROW(op.execute(&ctx));
+}
+
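The two empty-input tests added above pin down the vacuous-truth convention for is_strictly_increasing / is_non_decreasing: a {1, 0, 3} array has no adjacent pairs to compare, so both ops report true. A minimal plain-Java sketch of the same convention (illustrative helper only, not part of this changeset):

    // Hypothetical helper mirroring the semantics the C++ tests assert:
    // with no adjacent pair to violate x[i] < x[i+1], the answer is vacuously true.
    public class VacuousTruthDemo {
        static boolean isStrictlyIncreasing(float[] x) {
            for (int i = 0; i + 1 < x.length; i++) {
                if (x[i] >= x[i + 1]) return false;  // found a violating pair
            }
            return true;  // empty and single-element inputs fall through to here
        }

        public static void main(String[] args) {
            System.out.println(isStrictlyIncreasing(new float[0]));         // true, like the empty NDArray case
            System.out.println(isStrictlyIncreasing(new float[]{3f, 1f}));  // false
        }
    }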
 TEST_F(JavaInteropTests, Test_empty_cast_1) {
 auto x = NDArrayFactory::create('c', {1, 0, 2});
 auto z = NDArrayFactory::create('c', {1, 0, 2});
diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index f48ee54f6..ffcd5759e 100644
--- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp
@@ -674,6 +674,7 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
 x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
 y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
 z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
+ nullptr, nullptr,
 0,
 tadPackY.platformShapeInfo(), tadPackY.platformOffsets(),
 tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp index 747ecc183..0f3cab509 100644
--- a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp
@@ -1891,7 +1891,7 @@ TEST_F(NDArrayTest, TestMMulMultiDim) {
 ASSERT_TRUE(result->isSameShape(&expected));
 //result->printShapeInfo("result shape");
- //result->printBuffer("result buffer");
+ // result->printBuffer("result buffer");
 ASSERT_TRUE(result->equalsTo(&expected));
 delete result;
 }
diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp index 95b3027cc..1846fc397 100644
--- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp
@@ -202,6 +202,7 @@ printf("Unsupported for cuda now.\n");
 nullptr, nullptr,
 exp.buffer(), exp.shapeInfo(),
 nullptr, nullptr,
+ nullptr,
 dimension.buffer(), dimension.shapeInfo(),
 nullptr, nullptr);
 ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0));
diff --git a/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp b/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp index 998b8164b..68c68aafb 100644
--- a/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp
@@ -61,8 +61,10 @@ public:
 TEST_F(PerformanceTests, test_maxpooling2d_1) {
 std::vector valuesX;
- auto x = NDArrayFactory::create<float>('c', {32, 3, 224, 224});
- auto z = NDArrayFactory::create<float>('c', {32, 3, 224, 224});
+ // auto x = NDArrayFactory::create<float>('c', {32, 3, 224, 224});
+ // auto z = NDArrayFactory::create<float>('c', {32, 3, 224, 224});
+ auto x = NDArrayFactory::create<float>('c', {8, 3, 64, 64});
+ auto z = NDArrayFactory::create<float>('c', {8, 3, 64, 64});
 x.linspace(1.0f);
 Nd4jLong k = 5;
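The benchmark shrink above (32x3x224x224 down to 8x3x64x64) is mostly about runtime and working-set size; a quick plain-Java estimate of the effect, assuming the float arrays implied by x.linspace(1.0f):

    public class PoolBenchFootprint {
        public static void main(String[] args) {
            long before = 32L * 3 * 224 * 224;  // elements in the original NCHW shape
            long after = 8L * 3 * 64 * 64;      // elements in the reduced shape
            // 4 bytes per float; the test allocates two such arrays (x and z)
            System.out.printf("before: %d elements (~%.1f MB per array)%n", before, before * 4 / 1e6);
            System.out.printf("after:  %d elements (~%.2f MB per array)%n", after, after * 4 / 1e6);
            System.out.printf("reduction: %.0fx%n", (double) before / after);  // ~49x fewer elements
        }
    }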
diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index dfb685e22..051c65988 100644
--- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -33,6 +34,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -274,4 +276,69 @@ TEST_F(PlaygroundTests, test_relubp_1) {
 nd4j_printf("Time: %lld; BW: %f GB/s\n", time, bw);
 }
-*/
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(PlaygroundTests, my) {
+
+    int bS=8, iD=32,iH=32,iW=32, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
+    int oD,oH,oW;
+
+    nd4j::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0);
+
+    printf("!!%i, %i, %i\n", oD,oH,oW);
+
+    NDArray col('c', {bS, iC, kD, kH, kW, iD, iH, iW}, nd4j::DataType::DOUBLE);
+    NDArray vol('c', {bS, iC, oD, oH, oW}, nd4j::DataType::DOUBLE);
+
+    col = 3.77;
+    vol = -10.33;
+
+    auto variableSpace = new VariableSpace();
+    auto block = new Context(1, variableSpace, false);  // not-in-place
+
+    auto timeStart = std::chrono::system_clock::now();
+    nd4j::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW);
+    auto timeEnd = std::chrono::system_clock::now();
+    auto time = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
+
+    printf("time: %lld \n", time);
+
+    delete block;
+    delete variableSpace;
+}
+
+TEST_F(PlaygroundTests, my2) {
+
+    int bS=32, iD=32,iH=64,iW=64, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
+    int oD,oH,oW;
+
+    // nd4j::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0);
+    nd4j::ops::ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW,dH, dW, iH, iW, 0);
+
+    printf("!!%i, %i, %i\n", oD,oH,oW);
+
+    // NDArray col('c', {bS, iC, kD, kH, kW, iD, iH, iW}, nd4j::DataType::DOUBLE);
+    // NDArray vol('c', {bS, iC, oD, oH, oW}, nd4j::DataType::DOUBLE);
+    NDArray col('c', {bS, iC, kH, kW, iH, iW}, nd4j::DataType::DOUBLE);
+    NDArray im('c', {bS, iC, oH, oW}, nd4j::DataType::DOUBLE);
+
+    col = 3.77;
+    // vol = -10.33;
+    im = -10.33;
+
+    auto variableSpace = new VariableSpace();
+    auto block = new Context(1, variableSpace, false);  // not-in-place
+
+    auto timeStart = std::chrono::system_clock::now();
+    // nd4j::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW);
+    nd4j::ops::helpers::col2im(*col.getContext(), col, im, sH, sW, pH, pW, iH, iW, dH, dW);
+    auto timeEnd = std::chrono::system_clock::now();
+    auto time = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
+
+    printf("time: %lld \n", time);
+
+    delete block;
+    delete variableSpace;
+}
+
+*/
\ No newline at end of file
diff --git a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt index 315839dba..07cae9ae3 100644
--- a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt
+++ b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt
@@ -42,7 +42,7 @@ if ("${BUILD_MKLDNN}")
 set(HAVE_MKLDNN 1)
 add_definitions("-DHAVE_MKLDNN")
 include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_SOURCE_DIR}/external/mklml_lnx_2019.0.3.20190220/include ${mkldnn_SOURCE_DIR})
- set(MKLDNN mkldnn)
+ set(MKLDNN dnnl)
 endif()
 # Download and unpack flatbuffers at configure time
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index
8c80e3bb4..655e4159f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -255,6 +255,11 @@ public abstract class DifferentialFunction { } else { try { + //Edge case: we store float fields as doubles, rather than introduce an extra property + if(target.getType() == float.class && value instanceof Double){ + value = ((Double) value).floatValue(); + } + target.set(this,value); } catch (IllegalAccessException e) { throw new RuntimeException("Error setting property for function " + getClass().getName(), e); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 52e725191..abea31459 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -149,36 +149,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin; import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul; import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub; import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; -import org.nd4j.linalg.api.ops.impl.shape.Concat; -import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix; -import org.nd4j.linalg.api.ops.impl.shape.Cross; -import org.nd4j.linalg.api.ops.impl.shape.Diag; -import org.nd4j.linalg.api.ops.impl.shape.DiagPart; -import org.nd4j.linalg.api.ops.impl.shape.ExpandDims; -import org.nd4j.linalg.api.ops.impl.shape.Gather; -import org.nd4j.linalg.api.ops.impl.shape.GatherNd; -import org.nd4j.linalg.api.ops.impl.shape.MergeAvg; -import org.nd4j.linalg.api.ops.impl.shape.MergeMax; -import org.nd4j.linalg.api.ops.impl.shape.MeshGrid; -import org.nd4j.linalg.api.ops.impl.shape.OneHot; -import org.nd4j.linalg.api.ops.impl.shape.OnesLike; -import org.nd4j.linalg.api.ops.impl.shape.ParallelStack; -import org.nd4j.linalg.api.ops.impl.shape.Permute; -import org.nd4j.linalg.api.ops.impl.shape.Rank; -import org.nd4j.linalg.api.ops.impl.shape.ReductionShape; -import org.nd4j.linalg.api.ops.impl.shape.Repeat; -import org.nd4j.linalg.api.ops.impl.shape.Reshape; -import org.nd4j.linalg.api.ops.impl.shape.SequenceMask; -import org.nd4j.linalg.api.ops.impl.shape.Size; -import org.nd4j.linalg.api.ops.impl.shape.SizeAt; -import org.nd4j.linalg.api.ops.impl.shape.Slice; -import org.nd4j.linalg.api.ops.impl.shape.Squeeze; -import org.nd4j.linalg.api.ops.impl.shape.Stack; -import org.nd4j.linalg.api.ops.impl.shape.StridedSlice; -import org.nd4j.linalg.api.ops.impl.shape.Tile; -import org.nd4j.linalg.api.ops.impl.shape.Transpose; -import org.nd4j.linalg.api.ops.impl.shape.Unstack; -import org.nd4j.linalg.api.ops.impl.shape.ZerosLike; +import org.nd4j.linalg.api.ops.impl.shape.*; import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; @@ -260,40 +231,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; import 
org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ACos; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ASin; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ATan; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Cos; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1; -import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU; -import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid; -import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Log; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p; -import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid; -import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU; -import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SwishDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Tan; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; +import org.nd4j.linalg.api.ops.impl.transforms.strict.*; import org.nd4j.linalg.api.ops.random.custom.DistributionUniform; import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; import org.nd4j.linalg.api.ops.random.custom.RandomExponential; @@ -372,6 +310,15 @@ public class DifferentialFunctionFactory { return new ZerosLike(name, sameDiff(), input).outputVariable(); } + public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) { + return create(name, shape, 'c', initialize, dataType); + } + + public SDVariable create(String name, SDVariable shape, char order, boolean initialize, DataType dataType) { + validateDifferentialFunctionsameDiff(shape); + return new Create(name, sameDiff(), shape, order, initialize, dataType).outputVariable(); + } + public SDVariable onesLike(String name, SDVariable input, DataType dataType) { validateDifferentialFunctionsameDiff(input); return new OnesLike(name, sameDiff(), input, dataType).outputVariable(); @@ -1464,6 +1411,9 @@ public class DifferentialFunctionFactory { return new PowDerivative(sameDiff(), iX, false, pow).outputVariable(); } + public SDVariable mishDerivative(SDVariable iX) { + return new MishDerivative(sameDiff(), iX, false).outputVariable(); + } public SDVariable swish(SDVariable iX) { return new 
Swish(sameDiff(), iX, false).outputVariable(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index e89047ee4..5b4cb497b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -479,6 +479,8 @@ public class FlatBuffersMapper { } else if (v instanceof Number) { if (v instanceof Double) { d = new double[]{(Double) v}; + } else if (v instanceof Float){ + d = new double[]{(Float) v}; } else if (v instanceof Integer) { i = new int[]{(Integer) v}; } else if (v instanceof Long) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java index f76c42c50..33d983f23 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java @@ -417,7 +417,7 @@ public class LegacyOpMapper { case 4: return ScalarGreaterThanOrEqual.class; case 5: - return ScalarLessThanOrEqual.class; + return MatchCondition.class; case 6: return ScalarNotEquals.class; case 7: @@ -428,6 +428,8 @@ public class LegacyOpMapper { return ScalarXor.class; case 10: return ScalarNot.class; + case 11: + return ScalarLessThanOrEqual.class; default: throw new UnsupportedOperationException("No known scalar bool op for op number: " + opNum); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index bcc2b3c8b..7b19406ef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -46,6 +46,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class, org.nd4j.linalg.api.ops.custom.BarnesHutGains.class, org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class, + org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class, org.nd4j.linalg.api.ops.custom.KnnMinDistance.class, org.nd4j.linalg.api.ops.custom.SpTreeCell.class, org.nd4j.linalg.api.ops.custom.Flatten.class, @@ -64,6 +65,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp.class, org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp.class, org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo.class, + org.nd4j.linalg.api.ops.impl.shape.Create.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual.class, @@ -584,7 +586,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.BitCast.class, org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class, org.nd4j.linalg.api.ops.custom.DivideNoNan.class, - org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class, 
org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class ); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java index 233609b19..8605467cc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java @@ -267,7 +267,7 @@ public class TFGraphMapper { https://github.com/eclipse/deeplearning4j/issues/8285 */ DifferentialFunction dfInstance = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName); - Preconditions.checkState(dfInstance != null, "Could not find class for TF Ops: {}", opName); + Preconditions.checkState(dfInstance != null, "Could not find class for TF Ops: %s", opName); DifferentialFunction df; try { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java index 86ba0dfef..e166461e1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java @@ -28,7 +28,7 @@ import org.nd4j.linalg.activations.impl.*; public enum Activation { CUBE, ELU, HARDSIGMOID, HARDTANH, IDENTITY, LEAKYRELU, RATIONALTANH, RELU, RELU6, RRELU, SIGMOID, SOFTMAX, SOFTPLUS, SOFTSIGN, TANH, RECTIFIEDTANH, SELU, SWISH, - THRESHOLDEDRELU, GELU; + THRESHOLDEDRELU, GELU, MISH; /** * Creates an instance of the activation function @@ -77,6 +77,8 @@ public enum Activation { return new ActivationThresholdedReLU(); case GELU: return new ActivationGELU(); + case MISH: + return new ActivationMish(); default: throw new UnsupportedOperationException("Unknown or not supported activation function: " + this); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationMish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationMish.java new file mode 100644 index 000000000..b789a2925 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationMish.java @@ -0,0 +1,59 @@ +/* ***************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.activations.impl;
+
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.val;
+import org.nd4j.linalg.activations.BaseActivationFunction;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.primitives.Pair;
+
+/**
+ * Mish activation function: f(x) = x * tanh(softplus(x)).
+ *
+ * https://arxiv.org/ftp/arxiv/papers/1908/1908.08681.pdf
+ */
+@EqualsAndHashCode(callSuper = false)
+@Getter
+public class ActivationMish extends BaseActivationFunction {
+
+    @Override
+    public INDArray getActivation(INDArray in, boolean training) {
+        Nd4j.getExecutioner().execAndReturn(new Mish(in));
+        return in;
+    }
+
+    @Override
+    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
+        assertShape(in, epsilon);
+
+        val dLdZ = Nd4j.getExecutioner().exec(new MishDerivative(in, in));
+
+        dLdZ.muli(epsilon);
+        return new Pair<>(in, null);
+    }
+
+    @Override
+    public String toString() {
+        return "mish";
+    }
+
+}
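ActivationMish above just wraps the Mish/MishDerivative ops; for reference, a standalone plain-Java sketch of the function being computed, mish(x) = x * tanh(softplus(x)) from the linked paper (class and sample values are illustrative only):

    public class MishDemo {
        // softplus(x) = ln(1 + e^x); mish(x) = x * tanh(softplus(x))
        static double mish(double x) {
            double softplus = Math.log1p(Math.exp(x));
            return x * Math.tanh(softplus);
        }

        public static void main(String[] args) {
            System.out.println(mish(0.0));   // 0.0: x = 0 zeroes the product
            System.out.println(mish(5.0));   // ~4.9996: near-identity for large positive x
            System.out.println(mish(-5.0));  // ~-0.0336: soft saturation for negative x
        }
    }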
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 16931e434..77b946559 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
@@ -1864,7 +1864,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
 @Override
 public INDArray match(INDArray comp, Condition condition) {
- return Nd4j.getExecutioner().exec(new MatchConditionTransform(this,comp,condition));
+ // TODO: obviously, we can make this broadcastable, eventually. But this will require new CustomOp based on MatchCondition
+ Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal");
+ Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types must be equal");
+ return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition));
 }
 @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java index 65bdb5c1e..30f3d0bf5 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java
@@ -25,6 +25,7 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
+import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataBuffer;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -386,4 +387,15 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
 y = null;
 z = null;
 }
+
+ @Override
+ public String onnxName() {
+     throw new NoOpNameFoundException("No onnx op opName found for " + opName());
+ }
+
+ @Override
+ public String tensorflowName() {
+     throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
+ }
+
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java index 80f29e577..2d0ac235f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java
@@ -3,9 +3,13 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
 public class AdjustContrast extends BaseAdjustContrast {
 public AdjustContrast() {super();}
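The BaseOp defaults added above replace dozens of per-op throwing overrides (many are deleted later in this diff, e.g. ScalarAdd and Moments). A sketch of what callers can expect once an op no longer declares its own TF mapping, using an op whose tensorflowName() override this diff removes:

    import org.nd4j.imports.NoOpNameFoundException;
    import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd;

    public class OpNameLookupDemo {
        public static void main(String[] args) {
            try {
                // ScalarAdd's own tensorflowName() override is removed in this diff,
                // so the call now falls through to the throwing BaseOp default.
                String tfName = new ScalarAdd().tensorflowName();
                System.out.println("mapped to: " + tfName);
            } catch (NoOpNameFoundException e) {
                System.out.println("no TF mapping: " + e.getMessage());
            }
        }
    }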
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java index 5adfcbafd..9ebb3ea6f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java
@@ -3,9 +3,13 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
 public class AdjustContrastV2 extends BaseAdjustContrast {
 public AdjustContrastV2() {super();}
@@ -25,6 +29,6 @@ public class AdjustContrastV2 extends BaseAdjustContrast {
 @Override
 public String tensorflowName() {
- return "AdjustContrastV2";
+ return "AdjustContrastv2";
 }
}
\ No newline at end of file
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java index cadef80e6..25cddd741 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java
@@ -3,9 +3,13 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
 public abstract class BaseAdjustContrast extends DynamicCustomOp {
 public BaseAdjustContrast() {
 }
@@ -22,4 +26,11 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp {
 public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) {
 super("", sameDiff, vars);
 }
+
+ @Override
+ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+     int n = args().length;
+     Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+     return Collections.singletonList(inputDataTypes.get(0));
+ }
}
\ No newline at end of file
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java index 43bff11e6..ebae33fce 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java
@@ -3,16 +3,18 @@ import lombok.val;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
-import org.nd4j.linalg.factory.Nd4j;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;
+
+import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 public class BitCast extends DynamicCustomOp {
@@ -58,4 +60,11 @@ public class BitCast extends DynamicCustomOp {
 public String tensorflowName() {
 return "Bitcast";
 }
+
+ @Override
+ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+     int n = args().length;
+     Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+     return Collections.singletonList(inputDataTypes.get(0));
+ }
}
\ No newline at end of file
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java index 400830ec3..801384bfd 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java
@@ -3,8 +3,14 @@ import org.apache.commons.math3.analysis.function.Divide;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.shape.Shape;
+
+import java.util.Collections;
+import java.util.List;
 public class DivideNoNan extends DynamicCustomOp {
 public DivideNoNan() {
@@ -29,4 +35,12 @@
 public String tensorflowName() {
 return "DivNoNan";
 }
+
+ @Override
+ public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
+     Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got input %s", getClass(), dataTypes);
+
+     DataType z = Shape.pickPairwiseDataType(dataTypes.get(0), dataTypes.get(1));
+     return Collections.singletonList(z);
+ }
}
\ No newline at end of file
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java index 4c672a66c..57551c84c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java
@@ -2,9 +2,14 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
 public class DrawBoundingBoxes extends DynamicCustomOp {
 public DrawBoundingBoxes() {}
@@ -26,7 +31,14 @@
 }
 @Override
- public String tensorflowName() {
- return "DrawBoundingBoxes";
+ public String[] tensorflowNames() {
+ return new String[]{"DrawBoundingBoxes", "DrawBoundingBoxesV2"};
+ }
+
+ @Override
+ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+     int n = args().length;
+     Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+     return Collections.singletonList(inputDataTypes.get(0));
 }
}
\ No newline at end of file
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java index 303ac8458..ef150843d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java
@@ -3,9 +3,13 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
 public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
 public FakeQuantWithMinMaxVarsPerChannel() {}
@@ -33,4 +37,10 @@
 public String tensorflowName() {
 return "FakeQuantWithMinMaxVarsPerChannel";
 }
+
+ @Override
+ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+     Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 inputs, got %s", inputDataTypes);
+     return Collections.singletonList(inputDataTypes.get(0));
+ }
}
\ No newline at end of file
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java index d3007427d..5454d8953 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java
@@ -21,6 +21,7 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
+import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseBroadcastOp;
@@ -65,11 +66,6 @@
 return "BiasAddGrad";
 }
-
- @Override
- public String tensorflowName() {
-     return "BiasAddGrad";
- }
-
 @Override
 public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
 Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected 3 input data types for %s, got %s", getClass(), inputDataTypes);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastAddOp.java index c833d1786..805de3197 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastAddOp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastAddOp.java
@@ -61,14 +61,4 @@ public class BroadcastAddOp extends BaseBroadcastOp {
 public List<SDVariable> doDiff(List<SDVariable> f1) {
 return null;
 }
-
- @Override
- public String onnxName() {
-     throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
- }
-
- @Override
- public String tensorflowName() {
-     return "BroadcastAdd";
- }
}
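Several of the ops in this changeset gain the same calculateOutputDataTypes override; the shape of that contract, sketched as a hypothetical single-output op (the class and op name below are invented for illustration):

    import java.util.Collections;
    import java.util.List;
    import org.nd4j.base.Preconditions;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;

    // Follows the pattern used by BaseAdjustContrast, BitCast and DrawBoundingBoxes above:
    // validate that one data type was supplied per input, then declare each output's type.
    public class MyPassthroughOp extends DynamicCustomOp {
        @Override
        public String opName() {
            return "my_passthrough";  // hypothetical op name
        }

        @Override
        public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
            int n = args().length;
            Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n,
                    "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
            // Single output that inherits the first input's type
            return Collections.singletonList(inputDataTypes.get(0));
        }
    }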
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastGradientArgs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastGradientArgs.java index c90c7e1d3..b3aa5f244 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastGradientArgs.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastGradientArgs.java
@@ -61,14 +61,4 @@ public class BroadcastGradientArgs extends BaseBroadcastOp {
 public List<SDVariable> doDiff(List<SDVariable> f1) {
 return null;
 }
-
- @Override
- public String onnxName() {
-     throw new NoOpNameFoundException("No op name found for " + opName());
- }
-
- @Override
- public String tensorflowName() {
-     return "BroadcastGradientArgs";
- }
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java index e9ee1db2c..dc9017582 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java
@@ -64,7 +64,7 @@ public class BroadcastLessThanOrEqual extends BaseBroadcastBoolOp {
 @Override
 public int opNum() {
- return 5;
+ return 11;
 }
 @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java index 0fee6c238..9af5ede2e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java
@@ -51,7 +51,8 @@
 }
 public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
- @NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue){
+ @NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue,
+ INDArray output){
 super(new INDArray[]{image, cropBoxes, boxIndices, cropOutSize}, null);
 Preconditions.checkArgument(image.rank() == 4, "Input image must be rank 4 with shape [batch, height, width, channels], got %ndShape", image);
 Preconditions.checkArgument(cropBoxes.rank() == 2 && cropBoxes.size(1) == 4, "Crop boxes must be rank 2 with shape [num_boxes, 4], got %ndShape", cropBoxes);
@@ -60,6 +61,7 @@
 this.method = method;
 this.extrapolationValue = extrapolationValue;
 addArgs();
+ outputArguments.add(output);
 }
 @Override
@@ -89,8 +91,6 @@
 }
 protected void addArgs() {
- iArguments.clear();
- tArguments.clear();
 addIArgument(method == Method.BILINEAR ? 0 : 1);
 addTArgument(extrapolationValue);
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java index d7161cf5f..25463030c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java
@@ -53,7 +53,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
 @Override
 public String[] tensorflowNames() {
- return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"};
+ return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2", "NonMaxSuppressionV3"};
 }
 @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java index ab2984969..c127b76bc 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java
@@ -30,6 +30,7 @@ import onnx.Onnx;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
+import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
@@ -231,11 +232,6 @@ public class Pooling2D extends DynamicCustomOp {
 return "Pooling";
 }
-
- @Override
- public String tensorflowName() {
-     return "Pooling2D";
- }
-
 @Override
 public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
 Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java index 468224939..b2823e712 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
@@ -64,11 +64,6 @@ public class SigmoidCrossEntropyLoss extends BaseLoss {
 return "sigm_cross_entropy_loss";
 }
-
- @Override
- public String tensorflowName() {
-     return "sigmoid_cross_entropy";
- }
-
 @Override
 public List<SDVariable> doDiff(List<SDVariable> grad){
 //No external gradient
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java index 7afdf5166..c7aef3c62 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java
@@ -73,16 +73,6 @@ public class Moments extends DynamicCustomOp {
 return "moments";
 }
-
- @Override
- public String onnxName() {
-     throw new NoOpNameFoundException("No onnx op opName found for " + opName());
- }
-
- @Override
- public String tensorflowName() {
-     return "moments";
- }
-
 @Override
 public List<SDVariable> doDiff(List<SDVariable> grad){
 SDVariable dLdMean = grad.get(0);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java index 36b4600c9..945cd505d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java
@@ -69,16 +69,6 @@ public class NormalizeMoments extends DynamicCustomOp {
 return "normalize_moments";
 }
-
- @Override
- public String onnxName() {
-     throw new NoOpNameFoundException("No onnx op opName found for " + opName());
- }
-
- @Override
- public String tensorflowName() {
-     return "normalize_moments";
- }
-
 @Override
 public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
 Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java index 6ce3c1eef..c74acf734 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java
@@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.bool;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseReduceBoolOp;
@@ -64,12 +65,6 @@ public class IsInf extends BaseReduceBoolOp {
 return "HasInf";
 }
-
- @Override
- public String tensorflowName() {
-     return "HasInf";
- }
-
-
 @Override
 public List<SDVariable> doDiff(List<SDVariable> i_v) {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java index 2d25391a9..611219d3e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java
@@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.bool;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseReduceBoolOp;
@@ -64,12 +65,6 @@ public class IsNaN extends BaseReduceBoolOp {
 return "hasNaNs";
 }
-
- @Override
- public String tensorflowName() {
-     return "hasNans";
- }
-
-
 @Override
 public List<SDVariable> doDiff(List<SDVariable> i_v) {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java index 5ccca2dd2..9ecafd4c1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.custom; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -89,9 +90,4 @@ public class LogSumExp extends DynamicCustomOp { public String onnxName() { return "ReduceLogSumExp"; } - - @Override - public String tensorflowName() { - return "reduce_logsumexp"; - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java index c324ae68c..bb0dd4997 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java @@ -57,16 +57,6 @@ public class Entropy extends BaseReduceFloatOp { return "entropy"; } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - return "entropy"; - } - @Override public Type getOpType() { return Type.REDUCE_FLOAT; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java index 00bd2dbd1..f61c0dc43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.floating; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; import org.nd4j.linalg.ops.transforms.Transforms; @@ -80,8 +81,4 @@ public class Norm2 extends BaseReduceFloatOp { return "Norm"; } - @Override - public String tensorflowName() { - return "norm"; - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java index f90b871e8..d27215a80 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java @@ -54,17 +54,6 @@ public class CountNonZero extends BaseReduceLongOp { return "countNonZero"; } - - @Override - public String 
onnxName() { - throw new NoOpNameFoundException("No onnx name found for shape " + opName()); - } - - @Override - public String tensorflowName() { - return "count_nonzero"; - } - @Override public List doDiff(List f1) { return Collections.singletonList(f().zerosLike(arg())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java index 29c7eb878..cddf4bdea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java @@ -92,9 +92,4 @@ public class CosineDistance extends BaseReduce3Op { List diff = CosineSimilarity.doDiff(sameDiff, f(), larg(), rarg(), i_v1.get(0), keepDims, dimensions); return Arrays.asList(f().neg(diff.get(0)), f().neg(diff.get(1))); } - - @Override - public String tensorflowName() { - return "cosine_distance"; - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java index 01afd45b3..3aa31771a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java @@ -72,16 +72,6 @@ public class ScalarAdd extends BaseScalarOp { return "add_scalar"; } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - return "RealAdd"; - } - @Override public List doDiff(List i_v1) { SDVariable g = i_v1.get(0); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMax.java index 86dc551fb..3544d3d64 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMax.java @@ -58,16 +58,6 @@ public class ScalarMax extends BaseScalarOp { return "max_scalar"; } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - return "RealMax"; - } - @Override public List doDiff(List i_v1) { SDVariable mask = arg().gt(scalarValue.getDouble(0)).castTo(arg().dataType()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMin.java index 11a5701a6..29399eb5f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMin.java @@ -52,16 +52,6 @@ public class ScalarMin extends BaseScalarOp { return 13; } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No 
ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - return "RealMin"; - } - @Override public String opName() { return "scalar_min"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMultiplication.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMultiplication.java index 955152f1a..5a94eae0b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMultiplication.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarMultiplication.java @@ -66,15 +66,7 @@ public class ScalarMultiplication extends BaseScalarOp { public String opName() { return "mul_scalar"; } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - @Override - public String tensorflowName() { - return "RealMul"; - } @Override public List doDiff(List i_v1) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSubtraction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSubtraction.java index bde8a061a..53645c0ed 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSubtraction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSubtraction.java @@ -66,19 +66,6 @@ public class ScalarSubtraction extends BaseScalarOp { return "sub_scalar"; } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - return "RealSub"; - } - - - @Override public List doDiff(List i_v1) { SDVariable g = i_v1.get(0); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarAnd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarAnd.java index 12b1f8378..ce9cf9446 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarAnd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarAnd.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; @@ -61,12 +62,6 @@ public class ScalarAnd extends BaseScalarBoolOp { return "AndScalar"; } - @Override - public String tensorflowName() { - return "AndScalar"; - } - - @Override public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEps.java index 21d57fc6c..662535648 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEps.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEps.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarOp; @@ -87,11 +88,4 @@ public class ScalarEps extends BaseScalarBoolOp { public String onnxName() { return "ScalarEps"; } - - @Override - public String tensorflowName() { - return "ScalarEps"; - } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java index 5fb8046c3..ad2aa9b50 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java @@ -83,15 +83,4 @@ public class ScalarEquals extends BaseScalarBoolOp { return Arrays.asList(sameDiff.zerosLike(arg())); } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - return "equal"; - } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java index 34caac083..eafbcbc1a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java @@ -77,16 +77,4 @@ public class ScalarGreaterThan extends BaseScalarBoolOp { return Arrays.asList(sameDiff.zerosLike(arg())); } - - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - return "greater"; - } - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java index 656cc851c..0948c01ab 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java @@ -70,18 +70,6 @@ public class ScalarGreaterThanOrEqual extends BaseScalarBoolOp { return "greaterthanorequal_scalar"; } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - return "greater_equal"; - } - - @Override public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java index 09a317985..412b024f1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java @@ -63,12 +63,6 @@ public class ScalarLessThan extends BaseScalarBoolOp { return "Less"; } - @Override - public String tensorflowName() { - return "less"; - } - - @Override public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java index 38b5947e5..6c9a3a893 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java @@ -55,7 +55,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp { @Override public int opNum() { - return 5; + return 11; } @Override @@ -63,17 +63,6 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp { return "lessthanorequal_scalar"; } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - return "less_equal"; - } - - @Override public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNot.java index e35aba47a..6e2e1c885 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNot.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; @@ -66,12 +67,6 @@ public class ScalarNot extends BaseScalarBoolOp { return "NotScalar"; } - @Override - public String tensorflowName() { - return "Not_Scalar"; - } - - @Override public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java index 8e3065fc8..f050b686e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java @@ -63,17 +63,6 @@ public class ScalarNotEquals extends BaseScalarBoolOp { return "notequals_scalar"; } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - return "logical_not"; - } - - @Override public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarOr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarOr.java index 08161e80e..5fa7ccc8a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarOr.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarOr.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; @@ -66,11 +67,6 @@ public class ScalarOr extends BaseScalarBoolOp { return "OrScalar"; } - @Override - public String tensorflowName() { - return "Or_Scalar"; - } - @Override public List doDiff(List f1) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarXor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarXor.java index 325d47347..8319d3b6a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarXor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarXor.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; @@ -66,12 +67,6 @@ public class ScalarXor extends BaseScalarBoolOp { return "Xor_scalar"; } - @Override - public String tensorflowName() { - return "Xor_scalar"; - } - - @Override public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ApplyGradientDescent.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ApplyGradientDescent.java index 134869104..7442c390e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ApplyGradientDescent.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ApplyGradientDescent.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ops.DynamicCustomOp; import 
org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -52,12 +53,6 @@ public class ApplyGradientDescent extends DynamicCustomOp { return "ApplyGradientDescent"; } - @Override - public String tensorflowName() { - return "ApplyGradientDescent"; - } - - @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { /* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java new file mode 100644 index 000000000..0b966008c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java @@ -0,0 +1,138 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.shape; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * This operation creates a new array with a given shape, order and data type; it can optionally be zero-initialized + * + * @author raver119@gmail.com + */ +@Slf4j +public class Create extends DynamicCustomOp { + + protected boolean initialize = false; + protected char order = 'c'; + protected DataType outputType = DataType.FLOAT; //Allow customizing dtype for TF import + + public Create() { + } + + public Create(String name, SameDiff sameDiff, SDVariable input, boolean initialize) { + this(name, sameDiff, input, 'c', initialize, input.dataType()); + } + + public Create(String name, SameDiff sameDiff, SDVariable input, char order, boolean initialize, DataType dataType) { + super(name, sameDiff, new SDVariable[]{input}, false); + this.outputType = dataType; + this.initialize = initialize; + this.order = order; + + addArgs(); + } + + public Create(INDArray shape, DataType dataType) { + this(shape, 'c', false, dataType); + } + + public Create(INDArray shape, boolean initialize, DataType dataType) { + this(shape, 'c', initialize, dataType); + } + + public Create(@NonNull INDArray shape, char order, boolean initialize, DataType dataType) { + super(new INDArray[]{shape}, new INDArray[0]); + this.order = order; + this.initialize = initialize;
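+        // 'order' selects the memory layout: 'c' (row-major) or 'f' (column-major);
+        // 'initialize' requests that the new array be zero-filled rather than left
+        // uninitialized (it mirrors the boolean 'init' attribute of TF's Empty op, read below)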
this.outputType = dataType; + + addArgs(); + } + + protected void addArgs() { + addBArgument(initialize); + addIArgument((int) order,outputType.toInt()); + } + + @Override + public String opName() { + return "create"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No op found for " + opName()); + } + + @Override + public String tensorflowName() { + return "Empty"; + } + + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + // convert output data type + if(attributesForNode.containsKey("dtype")) { + outputType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType()); + } + + // get init field + if(attributesForNode.containsKey("init")) { + initialize = attributesForNode.get("init").getB(); + } + + // there's no order in TF, just plain C + this.order = 'c'; + addArgs(); + } + + @Override + public List doDiff(List i_v) { + SDVariable ret = sameDiff.zerosLike(outputVariables()[0]); + return Arrays.asList(ret); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.size() == 1, "Expected list with exactly 1 datatype for %s, got %s", getClass(), dataTypes); + if(outputType != null){ + return Collections.singletonList(outputType); + } else { + //Output type is same as input type + return dataTypes; + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java index c4747a371..3472be2de 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java @@ -18,14 +18,10 @@ package org.nd4j.linalg.api.ops.impl.shape; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -112,17 +108,6 @@ public class Eye extends DynamicCustomOp { addTArgument((double) dataType.toInt()); } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - return "Eye"; - } - - @Override public String opName() { return "eye"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java index ec86c6553..448ae1d16 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java @@ -21,10 +21,8 @@ import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import 
org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -69,16 +67,6 @@ public class MergeAvg extends DynamicCustomOp { super.initFromOnnx(node, initWith, attributesForNode, graph); } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - return "MergeAvg"; - } - @Override public List doDiff(List i_v) { int nArgs = args().length; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java index 046f06c3c..11578b902 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java @@ -22,11 +22,8 @@ import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -68,16 +65,6 @@ public class MergeMax extends DynamicCustomOp { super.initFromOnnx(node, initWith, attributesForNode, graph); } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - return "MergeMax"; - } - @Override public List doDiff(List i_v) { SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java index 3a1605d8b..c9eddf633 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java @@ -21,7 +21,6 @@ import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; @@ -46,17 +45,6 @@ public class ParallelStack extends DynamicCustomOp { super(null, sameDiff, values, false); } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx opName found for " + opName()); - } - - @Override - public String tensorflowName() { - return "parallel_stack"; - } - - @Override public String opName() { return "parallel_stack"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java index 4bf920da9..310e90352 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java @@ -21,7 +21,6 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; @@ -115,12 +114,6 @@ public class Repeat extends DynamicCustomOp { return "Repeat"; } - @Override - public String tensorflowName() { - return "Repeat"; - } - - @Override public List doDiff(List i_v) { SDVariable ret = outputVariables()[0]; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index 67454e231..60da43dc0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -18,11 +18,9 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; import lombok.val; -import onnx.OnnxMl; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; @@ -97,17 +95,6 @@ public class SequenceMask extends DynamicCustomOp { return "sequence_mask"; } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx opName found for " + opName()); - } - - - @Override - public String tensorflowName() { - return "SequenceMask"; - } - @Override public List doDiff(List grad){ //Input is integer indices diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java index abc55d8ec..66eeb9b99 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java @@ -18,9 +18,6 @@ package org.nd4j.linalg.api.ops.impl.transforms; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; @@ -45,16 +42,6 @@ public class Angle extends DynamicCustomOp { return "zeros_like"; } - @Override - public String onnxName() { throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - return "Angle"; - } - - @Override public List doDiff(List i_v) { return Collections.singletonList(f().zerosLike(arg())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java index 939ed854b..9c8607b98 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java @@ -80,7 +80,8 @@ public class MaxOut extends BaseTransformOp { @Override public String tensorflowName() { - return "Maxout"; + throw new NoOpNameFoundException("Tensorflow name not found for " + opName()); + //return "Maxout"; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java index 676a016df..4dd948b4f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java @@ -20,6 +20,7 @@ import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformBoolOp; @@ -59,12 +60,12 @@ public class BooleanNot extends BaseTransformBoolOp { @Override public String onnxName() { - return "not_applicable"; + throw new NoOpNameFoundException("Onnx name not found for " + opName()); } @Override public String tensorflowName() { - return "not_applicable"; + throw new NoOpNameFoundException("Tensorflow name not found for " + opName()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java index 0ae4b266c..dea1c9c3b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java @@ -53,6 +53,11 @@ public class MatchConditionTransform extends BaseTransformBoolOp { public MatchConditionTransform() {} + public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray y, @NonNull INDArray z, @NonNull Condition condition) { + this(x, z, Nd4j.EPS_THRESHOLD, condition); + this.y = y; + } + public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray z, @NonNull Condition condition) { this(x, z, Nd4j.EPS_THRESHOLD, condition); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java index 5483ce17d..1f4ee0107 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java @@ -61,7 +61,7 @@ public class BitwiseAnd extends 
BaseDynamicTransformOp { @Override public String tensorflowName() { - return "bitwise_and"; + return "BitwiseAnd"; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java index 3d05d883c..367e27497 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java @@ -22,7 +22,6 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import java.util.Arrays; import java.util.Collections; @@ -49,12 +48,6 @@ public class IsNumericTensor extends DynamicCustomOp { return "is_numeric_tensor"; } - - @Override - public String tensorflowName() { - return "IsNumericTensor"; - } - @Override public List doDiff(List f1) { throw new UnsupportedOperationException(""); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java index cd525a578..02a527cb8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java @@ -22,9 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; -import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -50,12 +48,6 @@ public class IsStrictlyIncreasing extends DynamicCustomOp { return "is_strictly_increasing"; } - - @Override - public String tensorflowName() { - return "IsStrictlyIncreasing"; - } - @Override public List doDiff(List f1) { return Collections.singletonList(sameDiff.zerosLike(arg())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogicalXor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogicalXor.java index 02e2bb436..c3c515b25 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogicalXor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogicalXor.java @@ -39,12 +39,6 @@ public class LogicalXor extends DynamicCustomOp { return "boolean_xor"; } - @Override - public String tensorflowName() { - return "LogicalXor"; - } - - @Override public List doDiff(List f1) { return Arrays.asList( sameDiff.zerosLike(larg()), sameDiff.zerosLike(rarg())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java index b4aa329a6..d1648abab 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java @@ -47,6 +47,29 @@ public class Reverse extends DynamicCustomOp { this.inPlace = true; } + + /** + * This constructor allows specifying the axes along which to reverse, e.g. new Reverse(x, 0, 1) reverses x along dimensions 0 and 1 + * @param x input array to reverse + * @param axis dimensions along which to reverse + */ + public Reverse(INDArray x, int... axis){ + super(new INDArray[]{x}, new INDArray[0]); + this.inPlace = false; + addIArgument(axis); + } + + /** + * This constructor allows specifying the axes along which to reverse, with the result written to a separate output array + * @param x input array to reverse + * @param z output array + * @param axis dimensions along which to reverse + */ + public Reverse(INDArray x, INDArray z, int... axis){ + super(new INDArray[]{x}, new INDArray[] {z}); + this.inPlace = false; + addIArgument(axis); + } + /** * Reverses whole array for compatibility with OldReverse. * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java index dbb2eb484..9c4d478c7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java @@ -62,11 +62,6 @@ public class SigmoidDerivative extends DynamicCustomOp { throw new NoOpNameFoundException("No onnx op opName found for " + opName()); } - @Override - public String tensorflowName() { - return "SigmoidGrad"; - } - @Override public List doDiff(List i_v) { throw new UnsupportedOperationException(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java index d21b8a1d6..b49a89200 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java @@ -20,6 +20,7 @@ import lombok.NonNull; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformBoolOp; import org.nd4j.linalg.api.ops.BaseTransformOp; @@ -68,7 +69,8 @@ public class Not extends BaseTransformBoolOp { @Override public String tensorflowName() { - return "Not"; + throw new NoOpNameFoundException("Tensorflow name not found for " + opName()); + //return "Not"; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java index f1b6be464..1ac8829e2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java @@ -18,11 +18,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -73,4 +76,11 @@ public class Abs extends BaseTransformSameOp { SDVariable ret = f().sign(arg()).mul(i_v.get(0)); return Arrays.asList(ret); } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java index 643e940be..ec91a98e6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java @@ -65,8 +65,10 @@ public class GELU extends BaseTransformStrictOp { } @Override - public String tensorflowName() { - return "GELU"; + public String tensorflowName() + { + throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); + //return "GELU"; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java new file mode 100644 index 000000000..416f74133 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java @@ -0,0 +1,77 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.strict; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.BaseTransformStrictOp; + +import java.util.Arrays; +import java.util.List; + +/** + * Mish activation function: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x)) + * + * @author raver119@gmail.com + */ +public class Mish extends BaseTransformStrictOp { + public Mish(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); + } + + public Mish() { + } + + public Mish(INDArray x, INDArray z) { + super(x, z); + } + + public Mish(INDArray ndArray) { + super(ndArray); + } + + @Override + public int opNum() { + return 57; + } + + @Override + public String opName() { + return "mish"; + } + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + return "Mish"; + } + + + @Override + public List doDiff(List i_v) { + SDVariable ret = f().mishDerivative(arg()).mul(i_v.get(0)); + return Arrays.asList(ret); + } + + +}
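The new op can be run eagerly through the executioner. A minimal usage sketch (not part of the patch), assuming the standard Nd4j entry points (createFromArray, zerosLike, getExecutioner) and the Mish(INDArray, INDArray) constructor added above:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish;
import org.nd4j.linalg.factory.Nd4j;

public class MishExample {
    public static void main(String[] args) {
        INDArray x = Nd4j.createFromArray(-1.0, 0.0, 1.0);
        // execute the strict transform, writing into a preallocated result array
        Mish op = new Mish(x, Nd4j.zerosLike(x));
        Nd4j.getExecutioner().exec(op);
        // mish(x) = x * tanh(ln(1 + e^x)): approximately [-0.3034, 0.0, 0.8651]
        System.out.println(op.z());
    }
}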
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/MishDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/MishDerivative.java new file mode 100644 index 000000000..ddbbf1d1b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/MishDerivative.java @@ -0,0 +1,79 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.strict; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.BaseTransformStrictOp; + +import java.util.List; + +/** + * Mish derivative + * + * @author raver119@gmail.com + */ +public class MishDerivative extends BaseTransformStrictOp { + public MishDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { + super(sameDiff, i_v1, i_v2); + } + + public MishDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { + super(sameDiff, i_v1, i_v2, inPlace); + } + + public MishDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); + } + + public MishDerivative() {} + + public MishDerivative(INDArray x, INDArray z) { + super(x, z); + } + + public MishDerivative(INDArray x) { + super(x); + } + + @Override + public int opNum() { + return 58; + } + + @Override + public String opName() { + return "_mishderivative"; + } + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException(); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java index 76ad98f23..ab565b30f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java @@ -65,7 +65,8 @@ public class PreciseGELU extends BaseTransformStrictOp { @Override public String tensorflowName() { - return "PreciseGELU"; + throw new NoOpNameFoundException("Tensorflow name not found for " + opName()); + //return "PreciseGELU"; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java index e690413d2..32a823ac1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java @@ -73,7 +73,8 @@ public class DropOut extends BaseRandomOp { @Override public String tensorflowName() { - return opName(); + throw new NoOpNameFoundException("No tensorflow op name found for: " + getClass().getName()); + //return opName(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressionDescriptor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressionDescriptor.java index 89a1ac4a3..42f9153eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressionDescriptor.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressionDescriptor.java @@ -21,6 +21,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import java.io.Serializable; +import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -142,7 +143,7 @@ public class CompressionDescriptor implements Cloneable, Serializable { directAlloc.putLong(numberOfElements); directAlloc.putLong(originalElementSize); directAlloc.putInt(originalDataType.ordinal()); - directAlloc.rewind(); + ((Buffer) directAlloc).rewind(); return directAlloc; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index c95dc5ef2..5e62dd198 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -93,6 +93,7 @@ import org.nd4j.versioncheck.VersionCheck; import java.io.*; import java.lang.reflect.Constructor; import java.math.BigDecimal; +import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; @@ -5681,7 +5682,7 @@ public class Nd4j { public static INDArray createNpyFromByteArray(@NonNull byte[] input) { ByteBuffer byteBuffer = ByteBuffer.allocateDirect(input.length); byteBuffer.put(input); - byteBuffer.rewind(); + ((Buffer) byteBuffer).rewind(); Pointer pointer = new Pointer(byteBuffer); return createFromNpyPointer(pointer); }
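The two ((Buffer) ...).rewind() changes above are the usual Java 8/9 cross-compilation workaround: JDK 9 redeclared rewind() (and flip(), clear(), etc.) on ByteBuffer with a covariant ByteBuffer return type, so bytecode compiled on JDK 9+ against the ByteBuffer overload fails on a Java 8 runtime with NoSuchMethodError. Casting to the java.nio.Buffer supertype pins the call to the method descriptor that exists on both. A minimal self-contained illustration (not part of the patch):

import java.nio.Buffer;
import java.nio.ByteBuffer;

public class RewindCompat {
    public static void main(String[] args) {
        ByteBuffer buf = ByteBuffer.allocateDirect(Long.BYTES);
        buf.putLong(42L);
        // Compiles to Buffer.rewind()Ljava/nio/Buffer;, present on Java 8 and 9+.
        // A plain buf.rewind() compiled on JDK 9+ would reference
        // ByteBuffer.rewind()Ljava/nio/ByteBuffer; and fail on a Java 8 runtime.
        ((Buffer) buf).rewind();
        System.out.println(buf.getLong()); // prints 42
    }
}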
+ * @return + */ + public static Condition epsNotEquals() { + // in case of pairwise MatchCondition we don't really care about number here + return epsNotEquals(0.0); + } + + /** + * This method will create Condition that checks if value is two values are not equal wrt eps + * + * @return + */ public static Condition epsNotEquals(Number value) { return new EpsilonNotEquals(value); } + /** + * This method will create Condition that checks if value is two values are equal wrt eps + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition epsEquals() { + // in case of pairwise MatchCondition we don't really care about number here + return epsEquals(0.0); + } + + /** + * This method will create Condition that checks if value is two values are equal wrt eps + * + * @return + */ public static Condition epsEquals(Number value) { return epsEquals(value, Nd4j.EPS_THRESHOLD); } + /** + * This method will create Condition that checks if value is two values are equal wrt eps + * + * @return + */ public static Condition epsEquals(Number value, Number epsilon) { return new EpsilonEquals(value, epsilon.doubleValue()); } + /** + * This method will create Condition that checks if value is two values are equal + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition equals() { + // in case of pairwise MatchCondition we don't really care about number here + return equals(0.0); + } + + /** + * This method will create Condition that checks if value is two values are equal + * + * @return + */ public static Condition equals(Number value) { return new EqualsCondition(value); } + /** + * This method will create Condition that checks if value is two values are not equal + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition notEquals() { + // in case of pairwise MatchCondition we don't really care about number here + return notEquals(0.0); + } + + /** + * This method will create Condition that checks if value is two values are not equal + * + * @return + */ public static Condition notEquals(Number value) { return new NotEqualsCondition(value); } + /** + * This method will create Condition that checks if value is value X is greater than value Y + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition greaterThan() { + // in case of pairwise MatchCondition we don't really care about number here + return greaterThan(0.0); + } + + /** + * This method will create Condition that checks if value is value X is greater than value Y + * + * @return + */ public static Condition greaterThan(Number value) { return new GreaterThan(value); } + /** + * This method will create Condition that checks if value is value X is less than value Y + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) 
+ * @return + */ + public static Condition lessThan() { + // in case of pairwise MatchCondition we don't really care about number here + return lessThan(0.0); + } + + /** + * This method will create Condition that checks if value is value X is less than value Y + * + * @return + */ public static Condition lessThan(Number value) { return new LessThan(value); } + /** + * This method will create Condition that checks if value is value X is less than or equal to value Y + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition lessThanOrEqual() { + // in case of pairwise MatchCondition we don't really care about number here + return lessThanOrEqual(0.0); + } + + /** + * This method will create Condition that checks if value is value X is less than or equal to value Y + * + * @return + */ public static Condition lessThanOrEqual(Number value) { return new LessThanOrEqual(value); } + /** + * This method will create Condition that checks if value is value X is greater than or equal to value Y + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition greaterThanOrEqual() { + // in case of pairwise MatchCondition we don't really care about number here + return greaterThanOrEqual(0.0); + } + + /** + * This method will create Condition that checks if value is value X is greater than or equal to value Y + * + * @return + */ public static Condition greaterThanOrEqual(Number value) { return new GreaterThanOrEqual(value); } + /** + * This method will create Condition that checks if value is value X is greater than or equal to value Y in absolute values + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition absGreaterThanOrEqual() { + // in case of pairwise MatchCondition we don't really care about number here + return absGreaterThanOrEqual(0.0); + } + + /** + * This method will create Condition that checks if value is value X is greater than or equal to value Y in absolute values + * + * @return + */ public static Condition absGreaterThanOrEqual(Number value) { return new AbsValueGreaterOrEqualsThan(value); } + /** + * This method will create Condition that checks if value is value X is less than or equal to value Y in absolute values + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition absLessThanOrEqual() { + // in case of pairwise MatchCondition we don't really care about number here + return absLessThanOrEqual(0.0); + } + + /** + * This method will create Condition that checks if value is value X is less than or equal to value Y in absolute values + * + * @return + */ public static Condition absLessThanOrEqual(Number value) { return new AbsValueLessOrEqualsThan(value); } + /** + * This method will create Condition that checks if value is value X is greater than value Y in absolute values + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) 
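+ *
+ * A minimal usage sketch for the pairwise, no-argument overloads (values illustrative; assumes an ND4J backend on the classpath):
+ * <pre>
+ * INDArray x = Nd4j.create(new double[]{1, 2, 3});
+ * INDArray y = Nd4j.create(new double[]{1, 5, 3});
+ * INDArray mask = x.match(y, Conditions.epsEquals()); // 1 where |x_i - y_i| <= eps, else 0
+ * </pre>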
+ * @return + */ + public static Condition absGreaterThan() { + // in case of pairwise MatchCondition we don't really care about number here + return absGreaterThan(0.0); + } + + /** + * This method will create Condition that checks if value is value X is greater than value Y in absolute values + * + * @return + */ public static Condition absGreaterThan(Number value) { return new AbsValueGreaterThan(value); } + /** + * This method will create Condition that checks if value is value X is less than value Y in absolute values + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition absLessThan() { + // in case of pairwise MatchCondition we don't really care about number here + return absLessThan(0.0); + } + + /** + * This method will create Condition that checks if value is value X is less than value Y in absolute values + * + * @return + */ public static Condition absLessThan(Number value) { return new AbsValueLessThan(value); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java index 83f7f5f4c..1ff6c8dfe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -26,16 +27,27 @@ import org.nd4j.linalg.lossfunctions.impl.*; public class LossFunctions { /** - * MSE: Mean Squared Error: Linear Regression
- * EXPLL: Exponential log likelihood: Poisson Regression
- * XENT: Cross Entropy: Binary Classification
- * MCXENT: Multiclass Cross Entropy
- * RMSE_XENT: RMSE Cross Entropy
- * SQUARED_LOSS: Squared Loss
- * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood
+ * MSE: Mean Squared Error: Linear Regression - {@link LossMSE}
+ * L1: L1 loss (absolute value) - {@link LossL1}
+ * XENT: Cross Entropy: Binary Classification - {@link LossBinaryXENT}
+ * MCXENT: Multiclass Cross Entropy - {@link LossMCXENT}
+ * SPARSE_MCXENT: Sparse multi-class cross entropy - {@link LossSparseMCXENT}
+ * SQUARED_LOSS: Alias for mean squared error - {@link LossMSE}
+ * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood - {@link LossNegativeLogLikelihood}
+ * COSINE_PROXIMITY: Cosine proximity loss - {@link LossCosineProximity}
+ * HINGE: Hinge loss - {@link LossHinge}
+ * SQUARED_HINGE: Squared hinge loss - {@link LossSquaredHinge}
+ * KL_DIVERGENCE: Kullback-Leibler divergence loss - {@link LossKLD}
+ * MEAN_ABSOLUTE_ERROR: Mean absolute error loss - {@link LossMAE}
+ * L2: L2 loss (sum of squared errors) - {@link LossL2}
+ * MEAN_ABSOLUTE_PERCENTAGE_ERROR: MAPE loss - {@link LossMAPE}
+ * MEAN_SQUARED_LOGARITHMIC_ERROR: MSLE loss - {@link LossMSLE}
+ * POISSON: Poisson loss - {@link LossPoisson}
+ * WASSERSTEIN: Wasserstein loss - {@link LossWasserstein} */ public enum LossFunction { - MSE, L1, @Deprecated EXPLL, XENT, MCXENT, @Deprecated RMSE_XENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, @Deprecated CUSTOM, COSINE_PROXIMITY, HINGE, SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN; + MSE, L1, XENT, MCXENT, SPARSE_MCXENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, COSINE_PROXIMITY, HINGE, + SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN; public ILossFunction getILossFunction() { switch (this) { @@ -48,6 +60,8 @@ public class LossFunctions { return new LossBinaryXENT(); case MCXENT: return new LossMCXENT(); + case SPARSE_MCXENT: + return new LossSparseMCXENT(); case KL_DIVERGENCE: case RECONSTRUCTION_CROSSENTROPY: return new LossKLD(); @@ -68,7 +82,6 @@ public class LossFunctions { case MEAN_SQUARED_LOGARITHMIC_ERROR: return new LossMSLE(); case POISSON: - case EXPLL: return new LossPoisson(); case WASSERSTEIN: return new LossWasserstein(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java index a985734ff..22bb27e0e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -40,10 +41,13 @@ import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; /** * * Multi-Class Cross Entropy loss function:
- * L = sum_i actual_i * log( predicted_i ) + * L = -sum_i actual_i * log( predicted_i )
+ * Note that labels are represented by a one-hot distribution
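+ * (for example, class 2 (zero indexed) of 4 classes is the one-hot row [0, 0, 1, 0])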
+ * See {@link LossSparseMCXENT} for the equivalent but with labels as integers instead * * @author Alex Black, Susan Eraly * @see LossNegativeLogLikelihood + * @see LossSparseMCXENT */ @EqualsAndHashCode @JsonInclude(JsonInclude.Include.NON_NULL) @@ -53,9 +57,9 @@ public class LossMCXENT implements ILossFunction { @JsonSerialize(using = NDArrayTextSerializer.class) @JsonDeserialize(using = NDArrayTextDeSerializer.class) - private INDArray weights; + protected INDArray weights; - private double softmaxClipEps; + protected double softmaxClipEps; public LossMCXENT() { this(null); @@ -91,7 +95,7 @@ public class LossMCXENT implements ILossFunction { this.softmaxClipEps = softmaxClipEps; } - private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + protected INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { if(!labels.equalShapes(preOutput)){ Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java index 162ac0797..a8453cf9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java new file mode 100644 index 000000000..2ea0feb52 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java @@ -0,0 +1,133 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.lossfunctions.impl; + + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.shape.OneHot; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.BooleanIndexing; +import org.nd4j.linalg.indexing.conditions.Conditions; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.LossUtil; +import org.nd4j.linalg.ops.transforms.Transforms; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer; +import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer; +import org.nd4j.shade.jackson.annotation.JsonInclude; +import org.nd4j.shade.jackson.annotation.JsonProperty; +import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; +import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; + +/** + * + * Sparse Multi-Class Cross Entropy loss function:
+ * L = -sum_i actual_i * log( predicted_i )
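+ *
+ * A brief usage sketch (shapes and values illustrative; exec semantics assumed):
+ * <pre>
+ * INDArray labels    = Nd4j.createFromArray(2, 0, 1).reshape(3, 1); // integer class indices, last dimension size 1
+ * INDArray preOutput = Nd4j.rand(3, 4);                             // pre-softmax activations for 4 classes
+ * ILossFunction loss = new LossSparseMCXENT();
+ * double score = loss.computeScore(labels, preOutput, new ActivationSoftmax(), null, true);
+ * </pre>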
+ * Note: this is the same loss function as {@link LossMCXENT}, the only difference being the format for the labels - + * this loss function uses integer indices (zero indexed) for the labels array, whereas LossMCXENT uses the equivalent + * one-hot representation + * + * @author Alex Black + * @see LossNegativeLogLikelihood + * @see LossMCXENT + */ +@EqualsAndHashCode(callSuper = true) +@JsonInclude(JsonInclude.Include.NON_NULL) +@Getter @Setter +public class LossSparseMCXENT extends LossMCXENT { + private static final double DEFAULT_SOFTMAX_CLIPPING_EPSILON = 1e-10; + + public LossSparseMCXENT() { + this(null); + } + + /** + * Multi-Class Cross Entropy loss function where each output is (optionally) weighted/scaled by a fixed scalar value. + * Note that the weights array must be a row vector, of length equal to the labels/output dimension 1 size. + * A weight vector of 1s should give identical results to no weight vector. + * + * @param weights Weights array (row vector). May be null. + */ + public LossSparseMCXENT(INDArray weights) { + this(DEFAULT_SOFTMAX_CLIPPING_EPSILON, weights); + } + + /** + * Multi-Class Cross Entropy loss function where each output is (optionally) weighted/scaled by a fixed scalar value. + * Note that the weights array must be a row vector, of length equal to the labels/output dimension 1 size. + * A weight vector of 1s should give identical results to no weight vector. + * + * @param weights Weights array (row vector). May be null. + */ + public LossSparseMCXENT(@JsonProperty("softmaxClipEps") double softmaxClipEps, @JsonProperty("weights") INDArray weights) { + super(softmaxClipEps, weights); + } + + protected INDArray sparseScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + INDArray oneHotLabels = toOneHot(labels, preOutput); + return super.scoreArray(oneHotLabels, preOutput, activationFn, mask); + } + + @Override + public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, + boolean average) { + INDArray oneHotLabels = toOneHot(labels, preOutput); + return super.computeScore(oneHotLabels, preOutput, activationFn, mask, average); + } + + @Override + public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + INDArray scoreArr = sparseScoreArray(labels, preOutput, activationFn, mask); + return scoreArr.sum(true,1).muli(-1); + } + + @Override + public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + INDArray oneHotLabels = toOneHot(labels, preOutput); + return super.computeGradient(oneHotLabels, preOutput, activationFn, mask); + } + + @Override + public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, + INDArray mask, boolean average) { + INDArray oneHotLabels = toOneHot(labels, preOutput); + return new Pair<>(super.computeScore(oneHotLabels, preOutput, activationFn, mask, average), + super.computeGradient(oneHotLabels, preOutput, activationFn, mask)); + } + + private INDArray toOneHot(INDArray labels, INDArray preOutput){ + Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " + + "with last dimension having size 1.
Got labels array with shape %ndShape", labels); + INDArray oneHotLabels = preOutput.ulike(); + Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1))); + return oneHotLabels; + } + + + @Override + public String toString() { + if (weights == null) + return "LossSparseMCXENT()"; + return "LossSparseMCXENT(weights=" + weights + ")"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java index 60bb0378b..ecc4cfe18 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java @@ -31,6 +31,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import java.io.*; +import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.channels.Channels; @@ -215,7 +216,7 @@ public class BinarySerde { allocated.put(shapeBuffer); allocated.put(buffer); if (rewind) - allocated.rewind(); + ((Buffer) allocated).rewind(); } /** @@ -247,7 +248,7 @@ //finally put the data allocated.put(buffer); if (rewind) - allocated.rewind(); + ((Buffer) allocated).rewind(); } @@ -314,7 +315,7 @@ ByteBuffer byteBuffer = buffer.order(ByteOrder.nativeOrder()); - buffer.position(0); + ((Buffer) buffer).position(0); int rank = byteBuffer.getInt(); val result = new long[Shape.shapeInfoLength(rank)]; @@ -324,7 +325,7 @@ // skipping two next values (dtype and rank again) // please note that this time rank has dtype of LONG, so it takes 8 bytes.
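+ // The ((Buffer) ...) casts in this patch work around a Java 8/9 binary-compatibility issue: JDK 9
+ // added covariant overrides such as ByteBuffer.position(int) and ByteBuffer.rewind() returning
+ // ByteBuffer, so code compiled on JDK 9+ without --release 8 records the ByteBuffer descriptor and
+ // fails with NoSuchMethodError on a Java 8 runtime. Casting the receiver to Buffer pins the
+ // descriptor that exists on both. A minimal sketch of the pattern (illustrative, not part of this class):
+ //   ByteBuffer buf = ByteBuffer.allocateDirect(16);
+ //   buf.putLong(42L);
+ //   ((Buffer) buf).rewind(); // resolves to Buffer.rewind(), present on Java 8 and later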
- byteBuffer.position(16); + ((Buffer) byteBuffer).position(16); // filling shape information for (int e = 1; e < Shape.shapeInfoLength(rank); e++) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java index 153183f57..f94ead011 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java @@ -36,6 +36,7 @@ import org.nd4j.linalg.util.ArrayUtil; import java.io.File; import java.io.FileInputStream; import java.io.InputStream; +import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.Charset; @@ -492,8 +493,8 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8")); ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder()); directBuffer.put(pathBytes); - directBuffer.rewind(); - directBuffer.position(0); + ((Buffer) directBuffer).rewind(); + ((Buffer) directBuffer).position(0); Pointer pointer = nativeOps.numpyFromFile(new BytePointer(directBuffer)); INDArray result = createFromNpyPointer(pointer); @@ -672,8 +673,8 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8")); ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder()); directBuffer.put(pathBytes); - directBuffer.rewind(); - directBuffer.position(0); + ((Buffer) directBuffer).rewind(); + ((Buffer) directBuffer).position(0); Pointer pointer = nativeOps.mapFromNpzFile(new BytePointer(directBuffer)); int n = nativeOps.getNumNpyArraysInMap(pointer); HashMap map = new HashMap<>(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index d4a7b8f8b..741978a3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -129,6 +129,7 @@ public interface NativeOps { @Cast("Nd4jLong *") LongPointer resultShapeInfo, Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java index de9edfc2e..73bb11bea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java @@ -106,7 +106,7 @@ public class NativeOpsHolder { boolean logInit = Boolean.parseBoolean(logInitProperty); if(logInit) { - log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads()); + 
log.info("Number of threads used for linear algebra: {}", deviceNativeOps.ompGetMaxThreads()); } } catch (Exception | Error e) { throw new RuntimeException( diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 20f2b5f22..904e1305e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -198,7 +198,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo, null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - null, + null, null, (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null); @@ -805,7 +805,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, + null, null, (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index e8b5e15c9..f567873a2 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -916,6 +916,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcastBool( @@ -927,6 +928,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcastBool( @@ -938,6 +940,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") long[] 
dDimensionShape); @@ -4600,6 +4603,11 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native @Cast("Nd4jLong") long sizeAt(int dim); + /** + * returns stride of "dim" dimension + */ + public native @Cast("Nd4jLong") long strideAt(int dim); + /** * returns order of array */ @@ -8019,6 +8027,12 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); + /** + * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! + */ + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); @@ -8032,6 +8046,12 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); + /** + * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! 
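+ *
+ * One plausible reading, for illustration only (semantics assumed from this description): for a c-order
+ * array of shape {2, 3, 4} with tadDims = {0, 2}, only dimensions 0 and 2 contribute, giving the dense
+ * sub-shape {2, 4}; coords {1, 3} would then map to index 1 * 4 + 3 = 7.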
+ */ + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); /** * increment n-dimensional array by one iteration by changing coord appropriately @@ -9088,6 +9108,8 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + @@ -9613,6 +9635,7 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include //#include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java index 67e8b4838..113960951 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java @@ -106,10 +106,11 @@ import org.bytedeco.javacpp.tools.InfoMapper; }, compiler = {"cpp11", "nowarnings"}, library = "jnind4jcuda", link = "nd4jcuda", preload = "libnd4jcuda"), - @Platform(value = "linux", preload = "gomp@.1", - preloadpath = {"/lib64/", "/lib/", "/usr/lib64/", "/usr/lib/", - "/usr/lib/powerpc64-linux-gnu/", - "/usr/lib/powerpc64le-linux-gnu/"})}) + @Platform(value = "linux", preload = "gomp@.1", preloadpath = {"/lib64/", "/lib/", "/usr/lib64/", "/usr/lib/"}), + @Platform(value = "linux-armhf", preloadpath = {"/usr/arm-linux-gnueabihf/lib/", "/usr/lib/arm-linux-gnueabihf/"}), + @Platform(value = "linux-arm64", preloadpath = {"/usr/aarch64-linux-gnu/lib/", "/usr/lib/aarch64-linux-gnu/"}), + @Platform(value = "linux-ppc64", preloadpath = {"/usr/powerpc64-linux-gnu/lib/", "/usr/powerpc64le-linux-gnu/lib/", "/usr/lib/powerpc64-linux-gnu/", "/usr/lib/powerpc64le-linux-gnu/"}), + @Platform(value = "windows", preload = {"libwinpthread-1", "libgcc_s_seh-1", "libgomp-1", "libstdc++-6", "libnd4jcpu"}) }) public class Nd4jCudaPresets implements InfoMapper { @Override public void map(InfoMap infoMap) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index a2964b7a6..751f75cea 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -967,6 +967,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { null, null, op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, null, + null, op.dimensions().data().addressPointer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, 
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index e2e9b0c2f..0ee807594 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -916,6 +916,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcastBool( @@ -927,6 +928,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcastBool( @@ -938,6 +940,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); @@ -4600,6 +4603,11 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native @Cast("Nd4jLong") long sizeAt(int dim); + /** + * returns stride of "dim" dimension + */ + public native @Cast("Nd4jLong") long strideAt(int dim); + /** * returns order of array */ @@ -8019,6 +8027,12 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); + /** + * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! 
+ */ + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); @@ -8032,6 +8046,12 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); + /** + * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! + */ + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); /** * increment n-dimensional array by one iteration by changing coord appropriately @@ -9088,6 +9108,8 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + @@ -11815,6 +11837,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include // #include +// #include //#include // #include @@ -17107,6 +17130,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -20557,7 +20581,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif - /** + /** * This op make bilinear or nearest neighbor interpolated resize for given tensor * * input array: @@ -20593,7 +20617,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif - /** + /** * This op make bilinear interpolated resize for given tensor * * input array: @@ -20628,7 +20652,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif - /** + /** * This op make nearest neighbor interpolated resize for given tensor * * input array: @@ -20640,7 +20664,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * 1 - new height * * output array: - * the 4D-Tensor with calculated backproped dots + * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels}) * * CAUTION: either size tensor or a pair of int params should be provided. */ @@ -20663,21 +20687,85 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif - /** - * This op calculates backprop dot for two tensors along given dimensions + /** + * This op make bicubic interpolated resize for given tensor * * input array: - * x: tensor to calculate dot for - * y: tensor to calculate dot for - * z: tensor with gradient output of the FF dot for x and y - * - * int arguments: - * list of integers - dimensions to calculate dot along, - * default corresponds to empty list in which case calculation - * is performed for all dimensions and scalar is returned. + * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - 1D-Tensor with 2 values (newWidth, newHeight) * * output array: - * the tensor with calculated backproped dots + * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels}) + * + */ +// #if NOT_EXCLUDED(OP_resize_bicubic) + @Namespace("nd4j::ops") public static class resize_bicubic extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public resize_bicubic(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public resize_bicubic(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public resize_bicubic position(long position) { + return (resize_bicubic)super.position(position); + } + + public resize_bicubic() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * This op make interpolated resize for given tensor with given algorithm. + * Supported algorithms are bilinear, bicubic, nearest_neighbor. 
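+ *
+ * A hedged sketch of invoking the op generically from Java (builder/exec semantics assumed; 0 selects
+ * the default bilinear algorithm per the int args below):
+ * <pre>
+ * INDArray image = Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3); // {batch, sizeX, sizeY, channels}
+ * INDArray size  = Nd4j.createFromArray(64, 64);            // {newWidth, newHeight}
+ * INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("image_resize")
+ *         .addInputs(image, size)
+ *         .addIntegerArguments(0)
+ *         .build());
+ * </pre>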
+ * Need to implement to full compatibility with TF: lanczos5, gaussian, area and mitchellcubic + * + * input array: + * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - 1D-Tensor with 2 values (newWidth, newHeight) + * + * optional int args: + * 0 - algorithm - bilinear by default + * optional bool args: + * 0 - preserve_aspect_ratio - default False + * 1 - antialias - default False + * + * output array: + * the 4D-Tensor with resized by given algorithm image (shape is {batch, newWidth, newHeight, channels}) + * + */ + +// #if NOT_EXCLUDED(OP_image_resize) + @Namespace("nd4j::ops") public static class image_resize extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public image_resize(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public image_resize(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public image_resize position(long position) { + return (image_resize)super.position(position); + } + + public image_resize() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * Copy a tensor setting everything outside a central band in each innermost matrix + * + * input array: + * x: given tensor with shape {..., M, N} - as vector (matrix) of matricies MxN + * + * int arguments: + * lower band + * upper band + * + * output array: + * matrix with given bands between lower and upper diagonals * */ @@ -20717,7 +20805,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif - /* + + /** * image.non_max_suppression op. * input: * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type @@ -21251,6 +21340,36 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif + + /** + * This operation creates new array + * Input: + * array with shape values + * + * IArgs: + * order value + * data type value + * + * BArgs: + * initialization option + */ +// #if NOT_EXCLUDED(OP_create) + @Namespace("nd4j::ops") public static class create extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public create(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. 
*/ + public create(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public create position(long position) { + return (create)super.position(position); + } + + public create() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index 7f3fc8256..9d067b5bc 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -19,6 +19,7 @@ package org.nd4j.nativeblas; import org.bytedeco.javacpp.annotation.Platform; import org.bytedeco.javacpp.annotation.Properties; import org.bytedeco.javacpp.tools.*; +import org.bytedeco.openblas.global.openblas; import java.io.File; import java.io.IOException; @@ -31,7 +32,7 @@ import java.util.Scanner; * * @author saudet */ -@Properties(target = "org.nd4j.nativeblas.Nd4jCpu", helper = "org.nd4j.nativeblas.Nd4jCpuHelper", +@Properties(inherit = openblas.class, target = "org.nd4j.nativeblas.Nd4jCpu", helper = "org.nd4j.nativeblas.Nd4jCpuHelper", value = {@Platform(define = "LIBND4J_ALL_OPS", include = { "memory/MemoryType.h", "array/DataType.h", @@ -127,15 +128,13 @@ import java.util.Scanner; "ops/declarable/headers/third_party.h", "cnpy/cnpy.h" }, - compiler = {"cpp11", "nowarnings"}, library = "jnind4jcpu", link = "nd4jcpu", - preloadresource = {"org/bytedeco/openblas/"}, - preload = {"openblas", "openblas_nolapack", "libnd4jcpu"}), - @Platform(value = "linux", preload = {"gomp@.1"}, preloadpath = {"/lib64/", "/lib/", "/usr/lib64/", "/usr/lib/"}), - @Platform(value = {"linux-arm", "linux-ppc"}, preload = {"gomp@.1", "gcc_s@.1", "quadmath@.0", "gfortran@.5", "gfortran@.4", "gfortran@.3", "openblas@.0", "libnd4jcpu"}), + compiler = {"cpp11", "nowarnings"}, + library = "jnind4jcpu", link = "nd4jcpu", preload = "libnd4jcpu"), + @Platform(value = "linux", preload = "gomp@.1", preloadpath = {"/lib64/", "/lib/", "/usr/lib64/", "/usr/lib/"}), @Platform(value = "linux-armhf", preloadpath = {"/usr/arm-linux-gnueabihf/lib/", "/usr/lib/arm-linux-gnueabihf/"}), @Platform(value = "linux-arm64", preloadpath = {"/usr/aarch64-linux-gnu/lib/", "/usr/lib/aarch64-linux-gnu/"}), @Platform(value = "linux-ppc64", preloadpath = {"/usr/powerpc64-linux-gnu/lib/", "/usr/powerpc64le-linux-gnu/lib/", "/usr/lib/powerpc64-linux-gnu/", "/usr/lib/powerpc64le-linux-gnu/"}), - @Platform(value = "windows", preload = {"libwinpthread-1", "libgcc_s_seh-1", "libgomp-1", "libstdc++-6", "msvcr120", "libnd4jcpu"}), + @Platform(value = "windows", preload = {"libwinpthread-1", "libgcc_s_seh-1", "libgomp-1", "libstdc++-6", "libnd4jcpu"}), @Platform(extension = {"-avx512", "-avx2"}) }) public class Nd4jCpuPresets implements InfoMapper, BuildEnabled { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 92022d258..ed3b5a7cb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -1641,6 +1641,19 @@ public class SameDiffTests extends BaseNd4jTest { } } + @Ignore(/*AS - 20191114 https://github.com/eclipse/deeplearning4j/issues/8393*/) + @Test + public void testIsStrictlyIncShape() { + int nOut = 0; + int minibatch = 0; + + INDArray ia = Nd4j.randn(minibatch, nOut); + INDArray expOut = Nd4j.create(DataType.BOOL, ia.shape()); + + Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut})); + System.out.println(expOut); + } + @Test public void testExpandDims2d() { val origShape = new long[]{3, 4}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index c57f8c5d9..83145a048 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -40,6 +41,7 @@ import org.nd4j.autodiff.validation.TestCase; import org.nd4j.base.Preconditions; import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.imports.listeners.ExecPrintListener; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; @@ -137,7 +139,7 @@ public class TFGraphTestAllHelper { protected static void checkOnlyOutput(Map inputs, Map predictions, String modelName, String baseDir, String modelFilename, ExecuteWith execType, BiFunction loader, - Double maxRelErrorOverride, Double minAbsErrorOverride) throws IOException { + Double maxRelErrorOverride, Double minAbsErrorOverride, boolean printArraysDebugging) throws IOException { Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" + " must be null or both must be provided"); Nd4j.EPS_THRESHOLD = 1e-3; @@ -152,7 +154,7 @@ public class TFGraphTestAllHelper { outputsToCheck.add(s); } - Pair> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck); + Pair> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck, printArraysDebugging); SameDiff graph = p.getFirst(); Map sameDiffPredictions = p.getSecond(); @@ -296,18 +298,18 @@ public class TFGraphTestAllHelper { } public static void checkIntermediate(Map inputs, String modelName, String baseDir, String modelFileName, - ExecuteWith execType, File localTestDir) throws IOException { - checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir); + ExecuteWith execType, File localTestDir, boolean printArraysDebugging) throws IOException { + checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir, printArraysDebugging); } public static void checkIntermediate(Map inputs, String modelName, String 
baseDir, String modelFileName, ExecuteWith execType, BiFunction loader, - Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir) throws IOException { + Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir, boolean printArraysDebugging) throws IOException { Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" + " must be null or both must be provided"); Nd4j.EPS_THRESHOLD = 1e-3; OpExecOrderListener listener = new OpExecOrderListener(); //Used to collect exec order - Pair> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null); + Pair> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null, printArraysDebugging); SameDiff graph = p.getFirst(); Map sdPredictions = p.getSecond(); @@ -388,13 +390,17 @@ public class TFGraphTestAllHelper { public static Pair> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map inputs, ExecuteWith executeWith, BiFunction graphLoaderFunction, List listeners, - Set requiredOutputs) throws IOException { + Set requiredOutputs, boolean printArraysDebugging) throws IOException { log.info("\n\tRUNNING TEST " + modelName + "..."); SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName); if(listeners != null){ graph.setListeners(listeners); } + if(printArraysDebugging){ + graph.addListeners(new ExecPrintListener()); + } + if(requiredOutputs == null){ requiredOutputs = graph.variableMap().keySet(); } @@ -635,7 +641,7 @@ public class TFGraphTestAllHelper { for (int i = 0; i < resources.size(); i++) { URI u = resources.get(i).getFirst().getURI(); String varName = u.toString(); - int idx = varName.lastIndexOf(modelName); + int idx = varName.indexOf(modelName); varName = varName.substring(idx + modelName.length()+1); //+1 for "/" varName = varName.replaceAll("____","/"); varName = varName.replaceAll(".placeholder.shape",""); @@ -752,7 +758,8 @@ public class TFGraphTestAllHelper { return (t, s) -> Nd4j.sort(t, true).equals(Nd4j.sort(s, true)); } - if(modelName.startsWith("alpha_dropout") || modelName.startsWith("layers_dropout")) + if(modelName.startsWith("alpha_dropout") || modelName.startsWith("layers_dropout") || modelName.equals("dropout")) + //We can't compare dropout using simple equality due to randomness return (t, s) -> { double[] tfNums = t.ravel().toDoubleVector(); double[] sdNums = s.ravel().toDoubleVector(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java index df7f4726a..a9870ea82 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -182,7 +183,7 @@ public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond()); TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, - TFGraphTestAllHelper.LOADER, maxRE, minAbs); + TFGraphTestAllHelper.LOADER, maxRE, minAbs, false); //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, EXECUTE_WITH); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 0d25f63d4..a690bc5a8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -70,8 +71,15 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a //Still failing 2019/09/11 "slogdet/.*", + // Failing 2019/11/14 - |https://github.com/eclipse/deeplearning4j/issues/8374 + "adjust_contrast/*", + "adjust_contrast/.*", //Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965 "bincount/.*", + // Failing 2019/11/15 https://github.com/eclipse/deeplearning4j/issues/8400 + "bitcast/.*", + // Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393 + "is_strictly_increasing/emptyArrayTest/.*", //TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too "truncatemod/.*", @@ -100,7 +108,25 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a "multinomial/.*", //2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation - "conv3d_transpose.*" + "conv3d_transpose.*", + + //2019/11/15 - mapping is not present yet https://github.com/eclipse/deeplearning4j/issues/8397 + "ragged/reduce_mean/.*", + + // 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398 + "zeros_like/rank2_float32_dtype_int.*", + + // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8399 + "crop_and_resize.*", + + // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8401 + "draw_bounding_boxes.*", + + // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402 + "fake_quant/min_max_args_per_channel.*", + + // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403 + "resize_bilinear/int32.*" }; @BeforeClass @@ -169,7 +195,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a Double minAbs = (precisionOverride == null ? 
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
 
         try {
-            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
             //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
         } catch (Throwable t){
             log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
index 3a39dac37..b0b344f64 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -20,11 +21,9 @@ import org.junit.*;
 import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
-import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.nativeblas.NativeOpsHolder;
 
@@ -51,9 +50,12 @@ public class TFGraphTestList {
     @Rule
     public TemporaryFolder testDir = new TemporaryFolder();
 
+    //Only enable this for debugging, and leave it disabled for normal testing and CI - it prints all arrays for every execution step
+    //Implemented internally using ExecPrintListener
+    public static final boolean printArraysDebugging = false;
+
     public static String[] modelNames = new String[]{
-//            "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
-            "accumulate_n/rank0"
+            "resize_nearest_neighbor/int32"
     };
 
     @After
@@ -102,7 +104,7 @@ public class TFGraphTestList {
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
 
         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, MODEL_DIR, MODEL_FILENAME, executeWith,
-                TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+                TFGraphTestAllHelper.LOADER, maxRE, minAbs, printArraysDebugging);
     }
 
     @Test @Ignore
@@ -110,7 +112,6 @@ public class TFGraphTestList {
         //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false);
         File dir = testDir.newFolder();
         Map<String,INDArray> inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir);
-        TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir);
-
+        TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir, printArraysDebugging);
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java
index 05edef2b8..d08fb5148 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -274,7 +275,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we
         currentTestDir = testDir.newFolder();
         log.info("----- SameDiff Exec: {} -----", modelName);
         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF,
-                LOADER, maxRE, minAbs);
+                LOADER, maxRE, minAbs, false);
 
         if(ArrayUtils.contains(IGNORE_REGEXES_LIBND4J_EXEC_ONLY, modelName)){
             log.warn("\n\tIGNORING MODEL FOR LIBND4J EXECUTION ONLY: ");
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java
new file mode 100644
index 000000000..6eff6b157
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java
@@ -0,0 +1,29 @@
+package org.nd4j.imports.listeners;
+
+import org.nd4j.autodiff.listeners.At;
+import org.nd4j.autodiff.listeners.BaseListener;
+import org.nd4j.autodiff.listeners.Operation;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.SameDiffOp;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+
+/**
+ * A very quick and dirty debugging listener
+ * This listener just prints the outputs of any ops during execution
+ * @author Alex Black
+ */
+public class ExecPrintListener extends BaseListener {
+    @Override
+    public boolean isActive(Operation operation) {
+        return true;
+    }
+
+    @Override
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+        System.out.println("------ Op: " + op.getName() + " - opName = " + op.getOp().opName() + ", class = " + op.getOp().getClass().getName() + " ------");
+        for(INDArray arr : outputs){
+            System.out.println(arr);
+        }
+    }
+}
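For reference, a minimal sketch of attaching the new ExecPrintListener by hand, mirroring what the printArraysDebugging flag wires up internally via graph.addListeners(...). The tiny graph and the outputSingle call below are illustrative assumptions, not taken from the patch:

import java.util.Collections;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.listeners.ExecPrintListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class ExecPrintListenerExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, 2, 2);
        sd.math().tanh("out", in);   // any simple op would do here

        // Attach the debug listener: every op execution now prints its output arrays
        sd.addListeners(new ExecPrintListener());

        sd.outputSingle(Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 2)), "out");
    }
}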
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
index 556405c14..ca075c872 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.custom;
 
+import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.junit.Ignore;
@@ -29,18 +30,23 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.custom.*;
 import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
 import org.nd4j.linalg.api.ops.executioner.OpStatus;
-import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
+import org.nd4j.linalg.api.ops.impl.controlflow.Where;
+import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
+import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
 import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
+import org.nd4j.linalg.api.ops.impl.shape.Create;
 import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
 import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
+import org.nd4j.linalg.api.ops.random.impl.DropOut;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.indexing.conditions.Conditions;
 import org.nd4j.nativeblas.NativeOpsHolder;
 
 import java.util.ArrayList;
@@ -823,6 +829,17 @@ public class CustomOpsTests extends BaseNd4jTest {
         assertEquals(expected, out);
     }
 
+    @Ignore("AS 11/13/2019 https://github.com/eclipse/deeplearning4j/issues/8374")
+    @Test
+    public void testAdjustContrastShape(){
+        DynamicCustomOp op = DynamicCustomOp.builder("adjust_contrast_v2")
+                .addInputs(Nd4j.create(DataType.FLOAT, 256, 256, 3), Nd4j.scalar(0.5f))
+                .build();
+        List<LongShapeDescriptor> lsd = op.calculateOutputShape();
+        assertEquals(1, lsd.size());
+        assertArrayEquals(new long[]{256, 256, 3}, lsd.get(0).getShape());
+    }
+
     @Test
     public void testAdjustContrastV2() {
         INDArray in = Nd4j.linspace(DataType.DOUBLE,1.0,1.0, 4*4*3).reshape(4,4,3);
@@ -840,6 +857,16 @@ public class CustomOpsTests extends BaseNd4jTest {
         assertEquals(expected, out);
     }
 
+    @Ignore("AS 11/13/2019 https://github.com/eclipse/deeplearning4j/issues/8374")
+    @Test
+    public void testBitCastShape(){
+        INDArray out = Nd4j.createUninitialized(1,10);
+        BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out);
+        List<LongShapeDescriptor> lsd = op.calculateOutputShape();
+        assertEquals(1, lsd.size());
+        assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape());
+    }
+
     @Test
     public void testBitCast() {
         INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2);
@@ -852,6 +879,79 @@ public class CustomOpsTests extends BaseNd4jTest {
         assertEquals(expected, out);
     }
 
+    @Ignore("AS 11/13/2019 https://github.com/eclipse/deeplearning4j/issues/8374")
+    @Test
+    public void testDrawBoundingBoxesShape() {
+        INDArray images = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
+                0.1804f,0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,
+                0.3087f,0.1548f,0.4695f,0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,
+                0.4601f,0.8284f,0.2354f,0.9752f,0.8361f,0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,
+                0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f,0.0755f,0.6245f,0.3491f,
+                0.5793f,0.5730f,0.1822f,0.6420f,0.9143f}).reshape(2,5,5,1);
+        INDArray boxes = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f,
+                0.6433f, 0.6041f, 0.6501f, 0.7612f,
+                0.7605f, 0.3948f, 0.9493f, 0.8600f,
+                0.7876f, 0.8945f, 0.4638f, 0.7157f}).reshape(2,2,4);
+        INDArray colors = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f}).reshape(1,2);
+        INDArray output = Nd4j.create(DataType.FLOAT, images.shape());
+        val op = new DrawBoundingBoxes(images, boxes, colors, output);
+        Nd4j.exec(op);
+        INDArray expected = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
+                0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.9441f,
+                0.9441f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f,
+                0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f,
+                0.2585f, 0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,
+                0.8428f, 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f});
+        assertEquals(expected, output);
+    }
+
+    @Ignore(" 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402")
+    @Test
+    public void testFakeQuantAgainstTF_1() {
+        INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
+                0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
+                0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
+        INDArray min = Nd4j.createFromArray(new float[]{ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
+        INDArray max = Nd4j.createFromArray(new float[]{ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
+
+        INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
+                0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
+                0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
+
+        INDArray out = Nd4j.createUninitialized(x.shape());
+        val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out);
+        Nd4j.exec(op);
+        assertEquals(expected, out);
+
+        /*TF: [[ 0.7801,  0.5966,  0.7260,  0.2320,  0.5084],
+              [ 0.1800,  0.5046,  0.8684,  0.3513,  0.5084],
+              [ 0.0877,  0.5966,  0.6600,  0.3513,  0.1604]]
+          SD: [[ 0.7770,  0.5969,  0.7232,  0.2310,  0.5098],
+              [ 0.1793,  0.5053,  0.8685,  0.3500,  0.5098],
+              [ 0.0874,  0.5969,  0.6574,  0.3500,  0.1597]]*/
+    }
+
+    @Test
+    public void testWhereFail() {
+        INDArray in = Nd4j.createFromArray(new float[]{0f, 1.0000f, 1.0000f, 1.0000f, 1.0000f});
+        INDArray out = Nd4j.createUninitialized(4,1);
+        INDArray expected = Nd4j.createFromArray(4,1);
+        val op = new Where(new INDArray[]{in}, new INDArray[]{out});
+        Nd4j.exec(op);
+        assertArrayEquals(new long[]{4,1} , out.shape());
+    }
+
+    @Ignore("2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403")
+    @Test
+    public void testResizeBilinear1() {
+
+        INDArray x = Nd4j.rand(1, 2,3,4);
+        INDArray z = Nd4j.createUninitialized(x.shape());
+        boolean align = false;
+        val op = new ResizeBilinear(x, z, 10, 10, align);
+        Nd4j.exec(op);
+    }
+
     @Test
     public void testCompareAndBitpack() {
         INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f,
@@ -932,6 +1032,30 @@ public class CustomOpsTests extends BaseNd4jTest {
         System.out.println(distance);
     }
 
+    @Ignore("2019/11/15 AS - https://github.com/eclipse/deeplearning4j/issues/8399")
+    @Test
+    public void testCropAndResize() {
+        INDArray image = Nd4j.createUninitialized(DataType.FLOAT, 1, 2, 2, 1);
+        INDArray boxes = Nd4j.createFromArray(new float[]{1,2,3,4}).reshape(1,4);
+        INDArray box_indices = Nd4j.createFromArray(new int[]{1});
+        INDArray crop_size = Nd4j.createFromArray(new int[]{1,2}).reshape(1,2);
+
+        //Output shape mismatch - TF [2, 2, 1, 1] vs SD: [1, 2, 1, 1]
+        INDArray output = Nd4j.create(DataType.FLOAT, 2,2,1,1);
+
+        Nd4j.exec(new CropAndResize(image, boxes, box_indices, crop_size, CropAndResize.Method.BILINEAR, 0.5,
+                output));
+    }
+
+    @Test
+    public void testLayersDropoutFail() {
+        INDArray input = Nd4j.rand(4, 5);
+        INDArray output = Nd4j.createUninitialized(4, 5);
+        DropOut op = new DropOut(input, output, 0.1);
+        Nd4j.exec(op);
+        System.out.println(output);
+    }
 
     @Test
     public void testRange(){
@@ -963,4 +1087,33 @@ public class CustomOpsTests extends BaseNd4jTest {
         assertEquals(1, lsd.size());
         assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
     }
+
+    @Test
+    public void testMatch_1() {
+        INDArray x = Nd4j.ones(DataType.FLOAT, 3,3);
+        INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3);
+        val c = Conditions.equals(0.0);
+
+        System.out.println("Y:\n" + y);
+
+        INDArray z = x.match(y, c);
+        INDArray exp = Nd4j.createFromArray(new boolean[][]{
+                {false, false, false},
+                {false, false, false},
+                {true, false, false}
+        });
+
+        assertEquals(exp, z);
+    }
+
+
+    @Test
+    public void testCreateOp_1() {
+        val shape = Nd4j.createFromArray(new int[] {3, 4, 5});
+        val exp = Nd4j.create(DataType.INT, 3, 4, 5);
+
+        val result = Nd4j.exec(new Create(shape, 'c', true, DataType.INT))[0];
+
+        assertEquals(exp, result);
+    }
 }
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java
index d01200db8..0e197f6cc 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java
@@ -50,7 +50,7 @@ public class TestLossFunctionsSizeChecks extends BaseNd4jTest {
 
     @Test
     public void testL2() {
-        LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.EXPLL, LossFunction.XENT,
+        LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.XENT,
                         LossFunction.MCXENT, LossFunction.SQUARED_LOSS, LossFunction.RECONSTRUCTION_CROSSENTROPY,
                         LossFunction.NEGATIVELOGLIKELIHOOD, LossFunction.COSINE_PROXIMITY, LossFunction.HINGE,
                         LossFunction.SQUARED_HINGE, LossFunction.KL_DIVERGENCE, LossFunction.MEAN_ABSOLUTE_ERROR,
diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArrayPublisher.java b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArrayPublisher.java
index 53cc03e70..6aaeae42d 100644
--- a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArrayPublisher.java
+++ b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArrayPublisher.java
@@ -32,6 +32,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.Closeable;
+import java.nio.Buffer;
 import java.nio.ByteBuffer;
 
 /**
@@ -129,7 +130,7 @@ public class AeronNDArrayPublisher implements AutoCloseable {
             NDArrayMessageChunk[] chunks = NDArrayMessage.chunks(message, publication.maxMessageLength() / 128);
             for (int i = 0; i < chunks.length; i++) {
                 ByteBuffer sendBuff = NDArrayMessageChunk.toBuffer(chunks[i]);
-                sendBuff.rewind();
+                ((Buffer) sendBuff).rewind();
                 DirectBuffer buffer = new UnsafeBuffer(sendBuff);
                 sendBuffer(buffer);
             }
diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java
index 7a0612982..469bf805b 100644
--- a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java
+++ b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java
@@ -28,6 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
 import java.io.Serializable;
+import java.nio.Buffer;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.time.Instant;
@@ -229,7 +230,7 @@ public class NDArrayMessage implements Serializable {
         for (int i = 0; i < chunks.length; i++) {
             ByteBuffer curr = chunks[i].getData();
             if (curr.capacity() > chunks[0].getChunkSize()) {
-                curr.position(0).limit(chunks[0].getChunkSize());
+                ((Buffer) curr).position(0).limit(chunks[0].getChunkSize());
                 curr = curr.slice();
             }
             all.put(curr);
@@ -311,7 +312,7 @@ public class NDArrayMessage implements Serializable {
 
         //rewind the buffer before putting it in to the unsafe buffer
         //note that we set rewind to false in the do byte buffer put methods
-        byteBuffer.rewind();
+        ((Buffer) byteBuffer).rewind();
         return new UnsafeBuffer(byteBuffer);
     }
 
diff --git a/pom.xml b/pom.xml
index f907a58ed..4a0992979 100644
--- a/pom.xml
+++ b/pom.xml
@@ -350,6 +350,7 @@
         2.8.0
         1.2.0-3f79e055
         4.10.0
+        3.8.3
         1.10.0
         1.14.0
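A note on the ((Buffer) ...) casts introduced in the two Aeron IPC files above: Java 9 gave ByteBuffer covariant overrides of rewind(), position(int) and limit(int) that return ByteBuffer, so code compiled on JDK 9+ against the ByteBuffer receiver records a method descriptor that does not exist on a Java 8 runtime and fails there with NoSuchMethodError. Casting to the Buffer supertype pins the Java 8-compatible descriptor. A self-contained sketch of the pattern (illustrative, not from the patch):

import java.nio.Buffer;
import java.nio.ByteBuffer;

public class BufferCastDemo {
    public static void main(String[] args) {
        ByteBuffer buf = ByteBuffer.allocate(16);
        buf.putInt(42);

        // If compiled on JDK 9+, "buf.rewind()" resolves to ByteBuffer.rewind(),
        // which is absent in Java 8 - running the same class file on a Java 8 JVM
        // then throws NoSuchMethodError. The cast below resolves to Buffer.rewind(),
        // which exists on both.
        ((Buffer) buf).rewind();

        System.out.println(buf.getInt()); // prints 42
    }
}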