Update codegen, add ctc loss
This commit is contained in:
		
							parent
							
								
									368ebb7e26
								
							
						
					
					
						commit
						228f6cda30
					
				
							
								
								
									
										614
									
								
								contrib/codegen-tools/codegen/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										614
									
								
								contrib/codegen-tools/codegen/README.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,614 @@ | ||||
| # ND4J Op Definitions and Code Generation | ||||
| This project contains the ND4J Op definitions, the DSL (Domain Specific Language) that is used for those definitions and | ||||
| code generators that use those definitions to create the actual Java code that is used to use the defined operations. | ||||
| 
 | ||||
| 
 | ||||
| ## Why define ops externally? | ||||
| As we started to support SameDiff, we also started to introduce inconsistencies between SameDiff and ND4J. Even though  | ||||
| both of those libraries use the same underlying implementations for operations, there are both small and large | ||||
| differences in the API that we provide for them. Sometimes, we have provided an official API only for one usage, and not | ||||
| the other. And very often the documentation for a single op is in many different places. | ||||
| 
 | ||||
| In the future we want to support other programming languages with libnd4j, and provide more ways to use our C++ backend. | ||||
| This would only increase the aforementioned problems. | ||||
|   | ||||
| The root of all of those problems, is that Ops are used across different environments, and there is no single way of  | ||||
| defining them with an enforced interface. | ||||
| 
 | ||||
| 
 | ||||
| ## How does this project help with enforcing a single consistent interface for ops? | ||||
| The solution we propose, is to define the operations separately, and then generate the necessary API code for them. All | ||||
| of the generated code is to be considered untouchable, editing it will result in the changes being overwritten sooner | ||||
| rather than later. | ||||
| 
 | ||||
| The combination of external op definition and code generation, opens up many opportunities for us. The first one being | ||||
| that we can easily create consistent APIs for both ND4J and SameDiff in Java. But, looking into the future, we can also  | ||||
| create those APIs for other programming languages like Python, Swift, or even C#. We can even go beyond programming | ||||
| languages, and use the op definitions to create better documentation than what JavaDoc or similar might support out of | ||||
| the box. | ||||
| 
 | ||||
| ## Maintenance | ||||
| This project is currently maintained by Paul Dubs, with feedback often collected from raver119 and Alex Black. | ||||
|   | ||||
| ## Current Status | ||||
| At the moment we still focus on nailing down an easily readable and contribution friendly DSL for op definition and code | ||||
| generation that can replace namespace definitions. This means that at the moment we still rely on the pre-existing Op | ||||
| definition classes that already exist in ND4J. | ||||
| 
 | ||||
| ## Roadmap | ||||
| * Replace Bitwise and Random namespaces with autogenerated code – In progress. | ||||
| * Implement a convenient CLI tool. | ||||
| * Define all Ops using the DSL.  | ||||
| * Automatically generate derivative op declarations from existing ops | ||||
| * Replace all namespace definitions in ND4J / SameDiff with automatically generated ones | ||||
| * Replace all Op classes with automatically generated ones. | ||||
| 
 | ||||
| # Usage | ||||
| Pre-requisites: | ||||
| * JDK 8 or higher | ||||
| * Maven 3.3 or higher | ||||
| 
 | ||||
| TODO: Show usage output of the project itself | ||||
| 
 | ||||
| TODO: Show how to use from mvn | ||||
| 
 | ||||
| 
 | ||||
| ## Generating Code - ND4J Namespaces | ||||
| 
 | ||||
| A script - `generate.sh` - is provided in the project root. This can be used (at present) to generate ND4J namespace classes. | ||||
| It is assumed that the deeplearning4j mono repo and the dl4j-dev-tools repo both exist and have a common parent directory | ||||
| i.e., `somedir/deeplearning4j` and `somedir/dl4j-dev-tools` both exist. | ||||
| 
 | ||||
| The script takes as argument the name (or names) of the ND4J namespaces to generate (not case sensitive) and projects (supported | ||||
| projects are nd4j, sd and both by default). | ||||
| 
 | ||||
| As of 26/11, namespaces names (and hence valid args) include: `bitwise`, `neuralnetwork`, `random`, and `math` | ||||
| Note also that `all` may be passed to the script to generate all namespaces. | ||||
| 
 | ||||
| For example, to generate both bitwise and random namespaces for both nd4j and SameDiff: | ||||
| ``` | ||||
| ./generate.sh bitwise,random | ||||
| ``` | ||||
| Or to generate all namespaces for both nd4j and SameDiff, use: | ||||
| ``` | ||||
| ./generate.sh all | ||||
| ```  | ||||
| To generate namespaces for one project only, use: | ||||
| ``` | ||||
| ./generate.sh linalg -projects sd | ||||
| ``` | ||||
| or: | ||||
| ``` | ||||
| ./generate.sh linalg -projects nd4j | ||||
| ``` | ||||
| The script will first compile the project, before running. | ||||
| Internally, the `org.nd4j.codegen.cli.CLI` class is used. | ||||
| Classes are written to `deeplearning4j/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/` | ||||
| 
 | ||||
| ## Generating documentation. | ||||
| It is possible to use generate.sh for generation of code only, docs in markdown format only, or both docs and code. | ||||
| To generate docs only and store them to new folder "docs" for all namespaces: | ||||
| ``` | ||||
| ./generate.sh all -docsdir ../../docs | ||||
| ``` | ||||
| Generation for selected namespaces works in the same way as for code: | ||||
| ``` | ||||
| ./generate.sh -docsdir ../../docs bitwise,linalg | ||||
| ``` | ||||
| 
 | ||||
| # Code structure | ||||
| The project is implemented using a mix of Java and Kotlin. The DSL definition and the accompanying data structures are | ||||
| implemented in Kotlin. At the moment the code generators are implemented in Java, in order to allow people who are not | ||||
| fluent in Kotlin, but know Java to be able to contribute to the code generators. | ||||
| 
 | ||||
| The source code for this project is structured a bit different that what you would typically see in a Java or Kotlin | ||||
| project. When you take a look inside the `src/main` directory, you will find 4 sub-directories.  | ||||
| 
 | ||||
| The `java` and `kotlin` directories contain Java and Kotlin code respectively.  | ||||
| 
 | ||||
| In order to not confuse op definitions with the machinery that allows them to be defined in that way, ops are kept in a | ||||
| separate folder called `ops`. | ||||
|   | ||||
| Because we use JavaPoet for Java code generator implementation, we also have yet another folder called `stubs`. That | ||||
| folder contains stub classes, that are used to reference other classes available in ND4J. These stub classes are | ||||
| intentionally left empty, as JavaPoet only requires them for naming and automatically creating proper imports. We use | ||||
| stub classes instead of depending on the actual nd4j API in order to break a cyclic dependency that would otherwise be | ||||
| created (i.e. in order to be able to generate code for ND4J, we would need an already compiled nd4j to be available). | ||||
| **Note:** If something is stubbed here and is moved in ND4J, then it also has to be moved to the appropriate place here, | ||||
| otherwise the generated code will be wrong. | ||||
| 
 | ||||
| The `adr` folder contains "Architecture Decision Records". These files give you more insight into the "why" of some of | ||||
| the bigger decisions within this project. | ||||
| 
 | ||||
| # DSL for Op Definition | ||||
| Ops are defined using a DSL that is implemented in Kotlin. This means that other than the DSL, as defined in the | ||||
| following, you can also use all of Kotlin when defining Ops. However, doing things the obvious and clearly | ||||
| understandable way is better than coming up with a clever way, so prefer to use the DSL as described if unsure. | ||||
| 
 | ||||
| ```kotlin | ||||
| val mathNs = Namespace("math") { | ||||
|     Op("add") { | ||||
|         javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic" | ||||
| 
 | ||||
|         Input(NUMERIC, "x") { description = "First input to add" } | ||||
|         Input(NUMERIC,"y") { count = AtLeast(1); description = "Second input to add" } | ||||
|         Arg(INT,"shape") { count = AtLeast(1); description = "shape" } | ||||
| 
 | ||||
| 
 | ||||
|         Output(NUMERIC, "z") { description = "Output (x+y)" } | ||||
| 
 | ||||
|         Doc(Language.ANY, DocScope.ALL) { | ||||
|             """ | ||||
|             (From AddOp) Add op doc text that will appear everywhere - classes, constructors, op creators | ||||
|             """.trimIndent() | ||||
|         } | ||||
|         Doc(Language.ANY, DocScope.CLASS_DOC_ONLY) { | ||||
|             "Add op doc text that will appear in all class docs (javadoc etc)" | ||||
|         } | ||||
|         Doc(Language.ANY, DocScope.CONSTRUCTORS_ONLY) { | ||||
|             "Add op doc text for constructors only" | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| This example shows how a namespace is defined. Namespaces are at the top layer, and ops can only be defined within the | ||||
| context of a namespace. This example namespace contains only a single op, called "add". If we wanted to add another op, | ||||
| we would simply add it below the first. | ||||
| 
 | ||||
| As you can see, every op has to have a name, if you try to create one without a name, you will get a compile error.  | ||||
| Within the context of the op, we first set in which java package the op class can be found in, then define its inputs,  | ||||
| arguments and outputs and finally add some free form documentation about what that op is doing. | ||||
| 
 | ||||
| Like with the op itself, the inputs, arguments and outputs all have to have a name, but unlike the op, they also require | ||||
| a type. Within their context, you can set a description and a count of how many parameters they can take respectively. | ||||
|   | ||||
| If an input, argument or output take anything else than exactly 1, they will be treated as arrays. Typically you would | ||||
| use this to define ops like `concat` which can take multiple input tensors or ops that might take shape arguments. | ||||
| 
 | ||||
| ## Examples | ||||
| The following shows how a typical op definition looks like and how the generated Java code may look. | ||||
| 
 | ||||
| An op might be defined like this: | ||||
| 
 | ||||
| ```kotlin | ||||
| Op("binomial") { | ||||
|     javaPackage = "org.nd4j.linalg.api.ops.random.custom" | ||||
|     Arg(INT, "nTrials") { description = "Number of trials parameter for the binomial distribution" } | ||||
|     Arg(FLOATING_POINT, "p") { description = "Probability of success for each trial" } | ||||
|     Arg(INT, "shape") { count = AtLeast(1); description = "Shape of the new random SDVariable, as a 1D array" } | ||||
| 
 | ||||
|     Output(NUMERIC, "output") { description = "new random SDVariable, where values are randomly sampled according to a Binomial distribution" } | ||||
| 
 | ||||
|     Doc(Language.ANY, DocScope.ALL) { | ||||
|         """ | ||||
|         Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution, | ||||
|         with the specified number of trials and probability. | ||||
|         """.trimIndent() | ||||
|     } | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| The java code generator will create a method like the following for it: | ||||
| ```java | ||||
|   /** | ||||
|    * Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution, | ||||
|    * with the specified number of trials and probability. | ||||
|    * | ||||
|    * @param nTrials Number of trials parameter for the binomial distribution | ||||
|    * @param p Probability of success for each trial | ||||
|    * @param shape Shape of the new random SDVariable, as a 1D array (Size: AtLeast(min=1)) | ||||
|    * @return output new random SDVariable, where values are randomly sampled according to a Binomial distribution (NUMERIC type) | ||||
|    */ | ||||
|   public static INDArray binomial(long nTrials, double p, long... shape) { | ||||
|     Preconditions.checkArgument(shape.length >= 1, "shape has incorrect count. Expected: AtLeast(min=1)"); | ||||
|     return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.BinomialOp(nTrials, p, shape))[0]; | ||||
|   } | ||||
| ``` | ||||
| 
 | ||||
| Or an op with some more constraints: | ||||
| 
 | ||||
| ```kotlin | ||||
| Op("and") { | ||||
|     javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom" | ||||
|     val x = Input(INT, "x") { description = "First input array" } | ||||
|     val y = Input(INT, "y") { description = "Second input array" } | ||||
|     Constraint("Must be same types"){ sameType(x, y) } | ||||
|     Constraint("Must have broadcastable shapes"){ broadcastableShapes(x, y) } | ||||
| 
 | ||||
|     Output(INT, "output"){ description = "Bitwise AND array" } | ||||
| 
 | ||||
|     Doc(Language.ANY, DocScope.ALL){ | ||||
|         """ | ||||
|         Bitwise AND operation. Supports broadcasting. | ||||
|         """.trimIndent() | ||||
|     } | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| will be converted to java like this: | ||||
| 
 | ||||
| ```java | ||||
|   /** | ||||
|    * Bitwise AND operation. Supports broadcasting. | ||||
|    * | ||||
|    * Inputs must satisfy the following constraints:  | ||||
|    * Must be same types: isSameType(x, y) | ||||
|    * Must have broadcastable shapes: isBroadcastableShapes(x, y) | ||||
|    * | ||||
|    * @param x First input array (INT type) | ||||
|    * @param y Second input array (INT type) | ||||
|    * @return output Bitwise AND array (INT type) | ||||
|    */ | ||||
|   public static INDArray and(INDArray x, INDArray y) { | ||||
|     NDValidation.validateInteger("and", x); | ||||
|     NDValidation.validateInteger("and", y); | ||||
|     Preconditions.checkArgument(isSameType(x, y), "Must be same types"); | ||||
|     Preconditions.checkArgument(isBroadcastableShapes(x, y), "Must have broadcastable shapes"); | ||||
|     return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.AndOp(x, y))[0]; | ||||
|   } | ||||
| ``` | ||||
| 
 | ||||
| # Full DSL Description | ||||
| ## Namespace | ||||
| 
 | ||||
|     fun NamespaceName() = Namespace("name"){ /* Op definitions in namespace context */} | ||||
|      | ||||
| Defines a namespace. | ||||
| 
 | ||||
| ## Op | ||||
| Only available within a namespace context | ||||
| 
 | ||||
|     Op("opName") { /* op properties in op context */ } | ||||
|     Op("anotherOp", mixin) { /* op properties in op context */ } | ||||
|     Op("anotherOp2", mixin, keepInputs=false) { /* op properties in op context */ } | ||||
|      | ||||
| Every op requires a namespace unique op name. | ||||
|   | ||||
| When defining an op, you can also pass a mixin that it should inherit initial properties from. This has the same effect | ||||
| as using `useMixin(mixin)` as the very first thing in the op definition. If you don't want to inherit all of the | ||||
| parameters of the mixin, you can pass the same additional configuration as you would pass to | ||||
| `useMixin(mixin, ...options..)`. See [Mixin](#mixin) for more information. | ||||
|   | ||||
| ### Op properties | ||||
| * `javaPackage` (String): Package where the op is to be found in the java implementation. | ||||
| * `javaOpClass` (String): Name of java op class if inconsistent with opName. Default: same as opName | ||||
| * `libnd4jName` (String): The name the op has in libnd4j. Default: same as opName | ||||
| 
 | ||||
| 
 | ||||
| ## Mixin | ||||
| Available in global context. | ||||
| 
 | ||||
|     val mixin = Mixin("name"){ /* op properties in op context */ } | ||||
|     // within an op context: | ||||
|     useMixin(mixin) | ||||
|     useMixin(mixin, ...options...) | ||||
|      | ||||
|     // When needing to access something from within the mixin | ||||
|     mixin.input("name") | ||||
|     mixin.arg("name") | ||||
|     mixin.config("name") | ||||
|     mixin.output("name") | ||||
|      | ||||
| Mixins provide the facility to share commonalities between Ops. You can think of it like inheritance, especially when | ||||
| you declare the use of a mixin on Op definition. In contrast to normal (single) inheritance where only a single super | ||||
| class is possible, the mixin mechanism allows to "inherit" from multiple sources.  | ||||
|   | ||||
| You can define almost all the same things within a mixin that you can within an Op. The only things that *can not* be | ||||
| configured within a mixin are Op `name`, `libnd4jName` and `javaOpClass`. | ||||
| 
 | ||||
| As mixins can be configured within the global context, you can share them across namespaces by defining them in their | ||||
| own file. If a mixin is namespace specific, you can also define it within the namespace context.  | ||||
| 
 | ||||
| Mixins are used either on definition as a parameter `Op("opname", mixin){...}`, or with `useMixin(mixin)` within the op | ||||
| definition. While the former version only supports a single mixin, the latter version allows you to use as many mixins | ||||
| as are required.  | ||||
| 
 | ||||
| You can also build up mixins by using `useMixin(mixin)` inside a Mixin itself. | ||||
| 
 | ||||
| `useMixin(mixin, ...options...)` supports a few additional options: `keepInputs`, `keepArgs`, `keepConfigs`,  | ||||
| `keepOutputs`, `keepSignatures`,  `keepDoc`, `keepConstraints`. They default to `true`. If you want to skip including | ||||
| some of them, you simply set the parameter for it to `false`, e.g. `useMixin(mixin, keepDoc=false)`.  | ||||
| 
 | ||||
| When using `useMixin(mixin)`, all definitions within the mixin are applied as if this invocation was replaced with the | ||||
| content of the mixin itself. This means, that if you have already defined anything prior to using a mixin, the mixin's | ||||
| definitions will be **after** the previously defined things. This can be very useful if the commonality  between ops is  | ||||
| that they have a few trailing options. | ||||
|                                                                | ||||
| If a named property or section is defined in both a mixin (or multiple mixins) and the op, then the **last** to define it will | ||||
| win. Named properties are `legacy`, `javaPackage`, named sections are `Input`, `Arg`, `Output`, `Config`. | ||||
| 
 | ||||
| For example, assume you have `javaPackage` defined in both an op and a mixin. Then you can have the following two | ||||
| cases: | ||||
| 
 | ||||
| First case: | ||||
| ```kotlin | ||||
|     Op("foo"){ | ||||
|         useMixin(exampleMixin) | ||||
|         javaPackage = "some.example.package" | ||||
|     } | ||||
| ```  | ||||
| 
 | ||||
| Second case: | ||||
| ```kotlin | ||||
|     Op("foo"){ | ||||
|         javaPackage = "some.example.package" | ||||
|         useMixin(exampleMixin) | ||||
|     } | ||||
| ```  | ||||
| 
 | ||||
| In the first case, the op will have the `javaPackage` value that is defined within the op. In the second case it will  | ||||
| have the `javaPackage` value defined in the mixin. | ||||
| 
 | ||||
| For inputs, args, outputs, it works similarly. Assume you have `Input(dataType, "a")` defined in both the mixin and the | ||||
| op. Again you can have two cases: | ||||
| 
 | ||||
| First case: | ||||
| ```kotlin | ||||
|     Op("foo"){ | ||||
|         useMixin(exampleMixin) | ||||
|         Input(NUMERIC, "a") | ||||
|     } | ||||
| ```  | ||||
| 
 | ||||
| Second case: | ||||
| ```kotlin | ||||
|     Op("foo"){ | ||||
|         Input(NUMERIC, "a") | ||||
|         useMixin(exampleMixin) | ||||
|     } | ||||
| ```  | ||||
| 
 | ||||
| In the first case, it will overwrite the input from the mixin. In the second case, the mixin will overwrite that the | ||||
| input from the op. | ||||
| 
 | ||||
| ## Config | ||||
| Only available within a namespace context | ||||
| 
 | ||||
|     val nameConfig = Config("Name"){ | ||||
|         /* input, arg, constraint, doc properties */ | ||||
|     } | ||||
| 
 | ||||
| Every config requires a namespace unique name. | ||||
| 
 | ||||
| A config allows to define a configuration class, that can be used as a holder for complex properties of specific ops  | ||||
| which will be passed to an op as a parameter. | ||||
| 
 | ||||
| Similar to an op itself, it supports `Input`, `Arg`, `Constraint` and `Doc` definitions.  | ||||
| 
 | ||||
| in order to use the config within an op you either use `useConfig(cfg)` or `val configRef = useConfig(cfg)`. The second | ||||
| form allows you to reference the config. | ||||
| 
 | ||||
| Referencing the config allows to you reference its inputs and args by name: `configRef.input("name")` and  | ||||
| `configRef.arg("name")`. Also it allows you to use a config in a signature `Signature(a, b, c, configRef)`. | ||||
| 
 | ||||
| When default and shorthand signatures are used, configs will be always placed at the end. | ||||
| 
 | ||||
| If a config is defined but not used, an `IllegalStateException` will be thrown. | ||||
| 
 | ||||
| See also [ADR 0007 "Configuration Objects"](adr/0007-configuration_objects.md). | ||||
| 
 | ||||
| 
 | ||||
| ## Input | ||||
| Available within an op, mixin and a config context | ||||
| 
 | ||||
|     Input(FLOATING_POINT, "b"){ /* input properties in input context */ } | ||||
|     val a = Input(INT, "a"){ /* input properties in input context */ } | ||||
|      | ||||
| Inputs represent tensors. They are what the op will work on. | ||||
| 
 | ||||
| Every input requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name. | ||||
| 
 | ||||
| When defining an input, you can assign it to a variable in order to be able to reference it later on. You might want to | ||||
| do this when defining constraints. | ||||
| 
 | ||||
| If you want an input to represent an array, you will have to set a count accordingly. If no count is set, it is assumed | ||||
| that the count is meant to be `Exactly(1)`. | ||||
|   | ||||
| ### Input properties | ||||
| * `description` (String): A short description what this input represents. Setting this is recommended. | ||||
| * `count` (Count): Can take one of `Exactly(n)`; `AtLeast(n)`; `AtMost(n)`; `Range(from, to)` | ||||
| * `defaultValue` (Input): use another input as the default if this isn't set explicitly. The data type of the other  | ||||
|   input has to match the data type of this input. The other input may also have a default value. | ||||
| 
 | ||||
| ## Argument | ||||
| Available within an op, mixin and config context | ||||
| 
 | ||||
|     Arg(FLOATING_POINT, "b"){ /* Arg properties in arg context */ } | ||||
|     val a = Arg(INT, "a"){ /* Arg properties in arg context */ } | ||||
|      | ||||
| Args represent arguments. They modify how the op works on its inputs. | ||||
| 
 | ||||
| Every arg requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name. | ||||
| 
 | ||||
| When defining an arg, you can assign it to a variable in order to be able to reference it later on. You might want to do | ||||
| this when defining constraints. | ||||
| 
 | ||||
| If you want an arg to represent an array, you will have to set a count accordingly. If no count is set, it is assumed | ||||
| that the count is meant to be `Exactly(1)`. | ||||
| 
 | ||||
| Note (Java specific): If the last arg is defined to represent an array, it will be translated to a vararg parameter, e.g. | ||||
| `Arg(INT, "a"){ count = AtLeast(1); description = "..." }` will be turned into `long... a`. | ||||
|   | ||||
| ### Argument properties | ||||
| * `description` (String): A short description what this argument represents. Setting this is recommended. | ||||
| * `count` (Count): Can take one of `Exactly(n)`; `AtLeast(n)`; `AtMost(n)`; `Range(from, to)` | ||||
| * `defaultValue` (null|Number|Boolean|int[]|double[]|boolean[]|Arg|TensorShapeValue|TensorDataTypeValue|String):  | ||||
|   Use given value as default value, if this isn't explicitly set. Can refer to *inputs* and *outputs* using `x.shape()` | ||||
|   and `x.dataType()`. The given default values has to match the data type for this argument. May also refer to another | ||||
|   Arg, and that Arg may also have a default value. Default values based on outputs are treated like without a default  | ||||
|   in SameDiff mode. | ||||
| * `possibleValues` (String[]): only available when ENUM data type is used for the argument. Takes a list of possible | ||||
|   values for the Enum. If used in in abstract base op, the enum will only be created once. See also  | ||||
|   [ADR 0006 "Op specific enums"](adr/0006-op_specific_enums.md). | ||||
|   | ||||
| 
 | ||||
| ## Output | ||||
| Only available within an op and mixin context | ||||
| 
 | ||||
|     Output(FLOATING_POINT, "b"){ /* Arg properties in arg context */ } | ||||
| 
 | ||||
| Every output requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name. | ||||
| 
 | ||||
| While outputs can be assigned to a variable, there is no intended use-case for it. In contrast to inputs and args,  | ||||
| outputs can not be used in constraints. | ||||
| 
 | ||||
| ### Output properties | ||||
| * `description` (String): A short description what this argument represents. Setting this is recommended. | ||||
| 
 | ||||
| 
 | ||||
| ## Signature | ||||
| Only available within an op and mixin context | ||||
| 
 | ||||
|     Signature(a,b,c) | ||||
|     Signature(a,b,c) { "Some Documentation" } | ||||
|     AllParamSignature() | ||||
|     AllDefaultParamSignature() | ||||
| 
 | ||||
| For some ops only specific signatures make sense, as for example some optional parameters may become required in the | ||||
| presence of other optional parameters. This feature is mainly meant to help with the fact that not all programming | ||||
| languages (e.g. Java) support default parameters. Each signature is meant to describe one overload in those languages.    | ||||
| 
 | ||||
| See also [ADR 0005 "Optional parameters and signatures"](adr/0005-optional_parameters_and_signatures.md). | ||||
| 
 | ||||
| Signatures can also reference the output(s) of an op. Those signatures are only relevant in NDArray programming mode. | ||||
| They are not to be generated in SameDiff mode. | ||||
| 
 | ||||
| `AllParamSignature()` and `AllDefaultParamSignature()` are short hands for `Signature(...all parameters...)` and  | ||||
| `Signature(...only parameters with no default values...)`. Their parameters include references to outputs unless | ||||
| disabled using `withOutput=false` (e.g. `AllParamSignature(withOutput=false)`). | ||||
| 
 | ||||
| If no signature is specified for an op, it is treated as if `AllParamSignature()` and `AllDefaultParamSignature()` are | ||||
| both specified.  | ||||
| 
 | ||||
| Each signature must satisfy the condition, that all required parameters are listed there. If this condition is not | ||||
| satisfied, an `IllegalStateException` will be thrown on construction.  | ||||
|    | ||||
| 
 | ||||
| ## Documentation | ||||
| Only available within an op and mixin context | ||||
| 
 | ||||
|     Doc(Language.ANY, DocScope.ALL){ | ||||
|         """ Some documentation | ||||
|         It can be multiline. And indented. | ||||
|         """.trimIndent() | ||||
|     } | ||||
|      | ||||
| Documentation can be language specific, and can be set to be only given at specific places. The documentation itself is | ||||
| given as a string. Because Kotlin supports multiline strings along with proper indentation, we are using them directly | ||||
| here. | ||||
| 
 | ||||
| Note: At the moment we are only creating java code, so the documentation can use JavaDoc syntax. | ||||
| 
 | ||||
| You can have multiple Doc definitions; they are treated as additive. | ||||
| 
 | ||||
| Any instances of the following values will be replaced when generating code: | ||||
| 
 | ||||
| * `%OPNAME%` -> operation name ("Add", "Sub", etc) | ||||
| * `%LIBND4J_OPNAME%` -> libnd4j op name ("add", "sub", etc) | ||||
| * `%INPUT_TYPE%` -> input / output type depending on the generated api, i.e. `SDVariable` for SameDiff and `INDArray` | ||||
|   for ND4J | ||||
| 
 | ||||
| See `DocTokens` class for more details.  | ||||
| 
 | ||||
| ## Constraints | ||||
| Available within an op, mixin and a config context. | ||||
| 
 | ||||
|     Constraint("Error Message if constraint isn't satisfied"){ /* constraint definition */ } | ||||
|     BackendConstraint("Error Message if constraint isn't satisfied"){ /* constraint definition */ } | ||||
| 
 | ||||
| Many ops expect their inputs and arguments to satisfy some specific rules. Those rules can be expressed with the | ||||
| constraint system. | ||||
| 
 | ||||
| Constraints are to be enforced within the frontend language, while BackendConstraints are currently only to be used as | ||||
| a part of the documentation. They will be enforced within the C++ backend, so there is no point in double checking them. | ||||
| 
 | ||||
| There is a system in place to define even complex constraints for inputs and arguments. | ||||
| 
 | ||||
| In a constraint definition, you can reference inputs and arguments directly, if they are previously assigned to | ||||
| a variable using `val name = Input(...)`. Inside the Constraint block, you can use the following operations: | ||||
| 
 | ||||
| * `eq`: Compare equality (applicable to numbers and booleans), e.g. `x eq 7`, `x eq true` | ||||
| * `neq`: Compare inequality (applicable to numbers and booleans), e.g. `x neq 3`, `x neq true` | ||||
| * `lt`, `lte`: less than, less than equal (applicable to numbers), e.g. `x lt 3`, `x lte 4` | ||||
| * `gt`, `gte`: greater than, grater than equal (applicable to numbers), e.g. `x gt 5`, `x gte 6` | ||||
| * `and`: combine two comparisons where both have to be true, e.g. `(x eq 8) and (y lt 3)` | ||||
| * `or`: combine two comparisons where one has to be true, e.g. `(x eq 8) or (y eq true)` | ||||
| * `all`: combine N comparisons where all have to be true, e.g. `all(x eq 8, y lt 3, z eq true)` | ||||
| * `some`: combine N comparisons where at least one has to be true, e.g. `some(x eq 8, y lt 3, z eq true)` | ||||
| * `not`: negates a comparison, e.g. `not(x eq 3)` | ||||
| 
 | ||||
| In addition to those operations, you also get access to some more complex constraints: | ||||
| * `sameType(...)`: true if all given inputs are the same type, e.g. `sameType(x,y,z)` | ||||
| * `sameShape(...)`: true if all given inputs have the same shape, e.g. `sameShape(x,y,z)` | ||||
| * `broadcastableShapes(...)`: true if all given inputs have broadcast compatible shapes, e.g. `broadcastableShapes(x,y,z)` | ||||
| 
 | ||||
| Inputs also get some additional methods on them to define useful constraints: | ||||
| * `input.rank()`: Rank of the given input | ||||
| * `input.sizeAt(i)`: size of the given input at the i-th dimension | ||||
| * `input.isScalar()`: Short hand for `x.rank() == 1` | ||||
| 
 | ||||
| ### Examples | ||||
| Some examples of constraints, and what they evaluate to. The example code contains a little bit of context. | ||||
| 
 | ||||
| ```kotlin | ||||
| val x = Input(INT, "x") { description = "First input array" } | ||||
| val y = Input(INT, "y") { description = "Second input array" } | ||||
| Constraint("foo bar"){ | ||||
|    x.sizeAt(7) eq 7 and y.isScalar() | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| will evaluate to: | ||||
| ```java | ||||
|     Preconditions.checkArgument((x.sizeAt(7) == 7) && (y.rank() == 1), "foo bar"); | ||||
| ``` | ||||
| 
 | ||||
| More examples (only the constraint itself, without context code): | ||||
| 
 | ||||
| #### Some | ||||
| ```kotlin | ||||
| some(input.rank() eq 3, input.sizeAt(2) gte 7, input.sizeAt(4) lt 5) | ||||
| ``` | ||||
| turns to: | ||||
| ```java | ||||
| ((x.rank() == 3) || (x.sizeAt(2) >= 7)) || (x.sizeAt(4) < 5) | ||||
| ``` | ||||
| 
 | ||||
| # Contributing to this project | ||||
| If you want to contribute to this project other than by adding or improving op definitions, the following sections might | ||||
| be of special interest to you. | ||||
| 
 | ||||
| ## Extending the DSL | ||||
| The DSL is implemented using Kotlin’s type-safe builders feature  | ||||
| (see https://kotlinlang.org/docs/reference/type-safe-builders.html). The basic principle is that functions calls can | ||||
| receive blocks that can be executed in a specified context. When combined with the fact that we are just looking to | ||||
| create an object graph that is then going to be used as input to the code generators, this allows us to create a very  | ||||
| feature rich DSL without actually having to write a lot of code to support it. | ||||
| 
 | ||||
| Most of the DSL specific code can be found in `src/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt`. The actual class | ||||
| definitions for the object graph we are building, can be found in `src/kotlin/org/nd4j/codegen/api`.  | ||||
| 
 | ||||
| If you want to add just a simple field to one of the objects, it is usually enough to directly add it to the particular | ||||
| class. | ||||
| 
 | ||||
| If you want to add a specific section to the op definition, i.e. a section like Input or Doc, you will have to add both | ||||
| the class for the object that it is going to be creating, as well as a function within OpBuilder.kt to create and | ||||
| register that section within the op. | ||||
|   | ||||
| **Note:** When you extend the DSL you will most likely also have to update all code generators to support the feature  | ||||
| you have added. | ||||
|   | ||||
| ## Adding / extending code generators | ||||
| Code generators can be written in either Java or Kotlin. Java has the advantage that more people will have experience in | ||||
| using it. Kotlin has the advantage of more convenient syntax, especially for plain string manipulation and when dealing | ||||
| with Enums and fixed sets of subclasses (called sealed classes in Kotlin). | ||||
| 
 | ||||
| All generators have to implement the `org.nd4j.codegen.api.generator.Generator` interface. For automatic detection by | ||||
| the CLI tool, they should also be within the `org.nd4j.codegen.impl.LANGUAGE` package, where `LANGUAGE` is the actual | ||||
| language that they generate. | ||||
| 
 | ||||
| Code generators can also use an auxiliary generator for constraint generation. Those auxiliary generators, have to | ||||
| implement ` org.nd4j.codegen.api.generator.ConstraintCodeGenerator` interface. | ||||
| 
 | ||||
							
								
								
									
										11
									
								
								contrib/codegen-tools/codegen/generate.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								contrib/codegen-tools/codegen/generate.sh
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,11 @@ | ||||
| #!/bin/bash | ||||
| 
 | ||||
| if test "$#" -eq 0; then | ||||
|     echo "No namespaces were specified. One or more namespaces must be provided as an argument" | ||||
|     echo "Usage example 1 (single namespace):      ./generate.sh math" | ||||
|     echo "Usage example 2 (multiple namespaces):   ./generate.sh math,random" | ||||
|     echo "Usage example 2 (all namespaces):        ./generate.sh all" | ||||
| else | ||||
|     mvn clean package -DskipTests | ||||
|     java -cp target/codegen-1.0.0-SNAPSHOT-shaded.jar org.nd4j.codegen.cli.CLI -dir ../../ -namespaces "$@" | ||||
| fi | ||||
							
								
								
									
										258
									
								
								contrib/codegen-tools/codegen/pom.xml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										258
									
								
								contrib/codegen-tools/codegen/pom.xml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,258 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <groupId>org.nd4j</groupId> | ||||
|     <artifactId>codegen</artifactId> | ||||
|     <version>1.0.0-SNAPSHOT</version> | ||||
| 
 | ||||
|     <properties> | ||||
|         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> | ||||
|         <commonsio.version>2.5</commonsio.version> | ||||
|         <commons.lang.version>3.12.0</commons.lang.version> | ||||
|         <commons.dbutils.version>1.7</commons.dbutils.version> | ||||
|         <lombok.version>1.18.8</lombok.version> | ||||
|         <logback.version>1.1.7</logback.version> | ||||
|         <junit.version>4.12</junit.version> | ||||
|         <junit-jupiter.version>5.4.2</junit-jupiter.version> | ||||
|         <java.version>1.8</java.version> | ||||
|         <maven-shade-plugin.version>3.1.1</maven-shade-plugin.version> | ||||
|         <kotlin.version>1.3.50</kotlin.version> | ||||
|         <kotlin.compiler.jvmTarget>1.8</kotlin.compiler.jvmTarget> | ||||
|         <kotlin.compiler.incremental>true</kotlin.compiler.incremental> | ||||
|     </properties> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.slf4j</groupId> | ||||
|             <artifactId>slf4j-api</artifactId> | ||||
|             <version>1.7.28</version> | ||||
|         </dependency> | ||||
| 
 | ||||
|         <dependency> | ||||
|             <groupId>ch.qos.logback</groupId> | ||||
|             <artifactId>logback-classic</artifactId> | ||||
|             <version>${logback.version}</version> | ||||
|         </dependency> | ||||
| 
 | ||||
|         <dependency> | ||||
|             <groupId>commons-io</groupId> | ||||
|             <artifactId>commons-io</artifactId> | ||||
|             <version>${commonsio.version}</version> | ||||
|         </dependency> | ||||
| 
 | ||||
|         <dependency> | ||||
|             <groupId>org.projectlombok</groupId> | ||||
|             <artifactId>lombok</artifactId> | ||||
|             <version>${lombok.version}</version> | ||||
|         </dependency> | ||||
| 
 | ||||
|         <dependency> | ||||
|             <groupId>org.apache.commons</groupId> | ||||
|             <artifactId>commons-lang3</artifactId> | ||||
|             <version>${commons.lang.version}</version> | ||||
|         </dependency> | ||||
| 
 | ||||
|         <dependency> | ||||
|             <groupId>com.squareup</groupId> | ||||
|             <artifactId>javapoet</artifactId> | ||||
|             <version>1.11.1</version> | ||||
|         </dependency> | ||||
| 
 | ||||
| 
 | ||||
|         <!-- Test Dependencies --> | ||||
|         <dependency> | ||||
|             <groupId>org.junit.jupiter</groupId> | ||||
|             <artifactId>junit-jupiter-api</artifactId> | ||||
|             <version>${junit-jupiter.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.junit.jupiter</groupId> | ||||
|             <artifactId>junit-jupiter-engine</artifactId> | ||||
|             <version>${junit-jupiter.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
| 
 | ||||
|         <dependency> | ||||
|             <groupId>org.jetbrains.kotlin</groupId> | ||||
|             <artifactId>kotlin-stdlib-jdk8</artifactId> | ||||
|             <version>${kotlin.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.jetbrains.kotlin</groupId> | ||||
|             <artifactId>kotlin-test</artifactId> | ||||
|             <version>${kotlin.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
| 
 | ||||
|         <dependency> | ||||
|             <groupId>com.fasterxml.jackson.module</groupId> | ||||
|             <artifactId>jackson-module-kotlin</artifactId> | ||||
|             <version>2.9.9</version> | ||||
|         </dependency> | ||||
| 
 | ||||
|         <dependency> | ||||
|             <groupId>com.beust</groupId> | ||||
|             <artifactId>jcommander</artifactId> | ||||
|             <version>1.78</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.nd4j</groupId> | ||||
|             <artifactId>nd4j-api</artifactId> | ||||
|             <version>1.0.0-SNAPSHOT</version> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
| 
 | ||||
|     <build> | ||||
|         <plugins> | ||||
|             <plugin> | ||||
|                 <groupId>org.codehaus.mojo</groupId> | ||||
|                 <artifactId>build-helper-maven-plugin</artifactId> | ||||
|                 <version>3.0.0</version> | ||||
|                 <executions> | ||||
|                     <execution> | ||||
|                         <id>add-source</id> | ||||
|                         <phase>generate-sources</phase> | ||||
|                         <goals><goal>add-source</goal></goals> | ||||
|                         <configuration> | ||||
|                             <sources> | ||||
|                                 <source>src/main/stubs</source> | ||||
|                             </sources> | ||||
|                         </configuration> | ||||
|                     </execution> | ||||
|                 </executions> | ||||
|             </plugin> | ||||
| 
 | ||||
|             <plugin> | ||||
|                 <groupId>org.apache.maven.plugins</groupId> | ||||
|                 <artifactId>maven-shade-plugin</artifactId> | ||||
|                 <version>${maven-shade-plugin.version}</version> | ||||
|                 <configuration> | ||||
|                     <shadedArtifactAttached>true</shadedArtifactAttached> | ||||
|                     <createDependencyReducedPom>false</createDependencyReducedPom> | ||||
|                     <filters> | ||||
|                         <filter> | ||||
|                             <artifact>*:*</artifact> | ||||
|                             <excludes> | ||||
|                                 <exclude>org/datanucleus/**</exclude> | ||||
|                                 <exclude>META-INF/*.SF</exclude> | ||||
|                                 <exclude>META-INF/*.DSA</exclude> | ||||
|                                 <exclude>META-INF/*.RSA</exclude> | ||||
|                             </excludes> | ||||
|                         </filter> | ||||
|                     </filters> | ||||
| 
 | ||||
|                 </configuration> | ||||
| 
 | ||||
|                 <executions> | ||||
|                     <execution> | ||||
|                         <phase>package</phase> | ||||
|                         <goals> | ||||
|                             <goal>shade</goal> | ||||
|                         </goals> | ||||
|                         <configuration> | ||||
|                             <transformers> | ||||
|                                 <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer"> | ||||
|                                     <resource>reference.conf</resource> | ||||
|                                 </transformer> | ||||
|                                 <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/> | ||||
|                                 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer" /> | ||||
|                             </transformers> | ||||
|                         </configuration> | ||||
|                     </execution> | ||||
|                 </executions> | ||||
|             </plugin> | ||||
| 
 | ||||
|             <plugin> | ||||
|                 <groupId>org.jetbrains.kotlin</groupId> | ||||
|                 <artifactId>kotlin-maven-plugin</artifactId> | ||||
|                 <version>${kotlin.version}</version> | ||||
|                 <configuration> | ||||
|                     <args> | ||||
|                         <arg>-Xjsr305=strict</arg> | ||||
|                     </args> | ||||
|                     <compilerPlugins> | ||||
|                         <plugin>spring</plugin> | ||||
|                         <plugin>jpa</plugin> | ||||
|                     </compilerPlugins> | ||||
|                 </configuration> | ||||
|                 <dependencies> | ||||
|                     <dependency> | ||||
|                         <groupId>org.jetbrains.kotlin</groupId> | ||||
|                         <artifactId>kotlin-maven-allopen</artifactId> | ||||
|                         <version>${kotlin.version}</version> | ||||
|                     </dependency> | ||||
|                     <dependency> | ||||
|                         <groupId>org.jetbrains.kotlin</groupId> | ||||
|                         <artifactId>kotlin-maven-noarg</artifactId> | ||||
|                         <version>${kotlin.version}</version> | ||||
|                     </dependency> | ||||
|                 </dependencies> | ||||
|                 <executions> | ||||
|                     <execution> | ||||
|                         <id>compile</id> | ||||
|                         <goals> <goal>compile</goal> </goals> | ||||
|                         <configuration> | ||||
|                             <sourceDirs> | ||||
|                                 <sourceDir>${project.basedir}/src/main/stubs</sourceDir> | ||||
|                                 <sourceDir>${project.basedir}/src/main/kotlin</sourceDir> | ||||
|                                 <sourceDir>${project.basedir}/src/main/java</sourceDir> | ||||
|                                 <sourceDir>${project.basedir}/src/main/ops</sourceDir> | ||||
|                             </sourceDirs> | ||||
|                         </configuration> | ||||
|                     </execution> | ||||
|                     <execution> | ||||
|                         <id>test-compile</id> | ||||
|                         <goals> <goal>test-compile</goal> </goals> | ||||
|                         <configuration> | ||||
|                             <sourceDirs> | ||||
|                                 <sourceDir>${project.basedir}/src/test/stubs</sourceDir> | ||||
|                                 <sourceDir>${project.basedir}/src/test/kotlin</sourceDir> | ||||
|                                 <sourceDir>${project.basedir}/src/test/java</sourceDir> | ||||
|                                 <sourceDir>${project.basedir}/src/test/ops</sourceDir> | ||||
|                             </sourceDirs> | ||||
|                         </configuration> | ||||
|                     </execution> | ||||
|                 </executions> | ||||
|             </plugin> | ||||
| 
 | ||||
|             <!-- https://kotlinlang.org/docs/reference/using-maven.html --> | ||||
|             <plugin> | ||||
|                 <groupId>org.apache.maven.plugins</groupId> | ||||
|                 <artifactId>maven-compiler-plugin</artifactId> | ||||
|                 <version>3.5.1</version> | ||||
|                 <executions> | ||||
|                     <!-- Replacing default-compile as it is treated specially by maven --> | ||||
|                     <execution> | ||||
|                         <id>default-compile</id> | ||||
|                         <phase>none</phase> | ||||
|                     </execution> | ||||
|                     <!-- Replacing default-testCompile as it is treated specially by maven --> | ||||
|                     <execution> | ||||
|                         <id>default-testCompile</id> | ||||
|                         <phase>none</phase> | ||||
|                     </execution> | ||||
|                     <execution> | ||||
|                         <id>java-compile</id> | ||||
|                         <phase>compile</phase> | ||||
|                         <goals> <goal>compile</goal> </goals> | ||||
|                     </execution> | ||||
|                     <execution> | ||||
|                         <id>java-test-compile</id> | ||||
|                         <phase>test-compile</phase> | ||||
|                         <goals> <goal>testCompile</goal> </goals> | ||||
|                     </execution> | ||||
|                 </executions> | ||||
|                 <configuration> | ||||
|                     <source>${java.version}</source> | ||||
|                     <target>${java.version}</target> | ||||
|                 </configuration> | ||||
|             </plugin> | ||||
|         </plugins> | ||||
|     </build> | ||||
| 
 | ||||
| </project> | ||||
| @ -19,17 +19,13 @@ | ||||
|  */ | ||||
| 
 | ||||
| package org.nd4j.codegen.impl.java; | ||||
| 
 | ||||
| import org.apache.commons.lang3.StringUtils; | ||||
| import org.nd4j.codegen.api.Language; | ||||
| import org.nd4j.codegen.api.Namespace; | ||||
| import org.nd4j.codegen.api.NamespaceOps; | ||||
| import org.nd4j.codegen.api.Op; | ||||
| import org.nd4j.codegen.api.generator.Generator; | ||||
| import org.nd4j.codegen.api.generator.GeneratorConfig; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| 
 | ||||
| public class JavaPoetGenerator implements Generator { | ||||
| 
 | ||||
| @ -40,12 +36,12 @@ public class JavaPoetGenerator implements Generator { | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void generateNamespaceNd4j(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws IOException { | ||||
|     public void generateNamespaceNd4j(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws java.io.IOException { | ||||
|         Nd4jNamespaceGenerator.generate(namespace, config, directory, className, "org.nd4j.linalg.factory", StringUtils.EMPTY); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void generateNamespaceSameDiff(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws IOException { | ||||
|     public void generateNamespaceSameDiff(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws java.io.IOException { | ||||
|         //throw new UnsupportedOperationException("Not yet implemented"); | ||||
|         Nd4jNamespaceGenerator.generate(namespace, config, directory, className, "org.nd4j.autodiff.samediff", StringUtils.EMPTY); | ||||
|     } | ||||
|  | ||||
| @ -47,6 +47,23 @@ fun SDLoss() =  Namespace("Loss"){ | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     Op("ctcLoss") { | ||||
|         javaPackage = "org.nd4j.linalg.api.ops.impl.loss" | ||||
|         javaOpClass = "CtcLoss" | ||||
|         Input(NUMERIC, "targetLabels") { description = "Label array" } | ||||
|         Input(NUMERIC, "logitInput") { description = "Inputs" } | ||||
|         Input(NUMERIC, "targetLabelLengths") { description = "Length of the target label" } | ||||
|         Input(NUMERIC, "logitInputLengths") { description = "Length of the input"} | ||||
|         Output(NUMERIC, "output"){ description = "Ctc loss " } | ||||
|         Doc(Language.ANY, DocScope.ALL){ | ||||
|             """ | ||||
|                 CTC Loss: Connectionist Temporal Classification Loss. See: | ||||
|                 https://dl.acm.org/citation.cfm?id=1143891 | ||||
|             """.trimIndent() | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     Op("cosineDistance") { | ||||
|         javaPackage = "org.nd4j.linalg.api.ops.impl.loss" | ||||
|         javaOpClass = "CosineDistanceLoss" | ||||
|  | ||||
| @ -1,32 +1,26 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff.ops; | ||||
| 
 | ||||
| import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; | ||||
| 
 | ||||
| import java.lang.String; | ||||
| import java.util.Arrays; | ||||
| import java.util.concurrent.atomic.AtomicInteger; | ||||
| import java.util.stream.Collectors; | ||||
| 
 | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| @ -307,7 +301,7 @@ public class SDBaseOps { | ||||
|    * @param transposeB Whether to transpose B arrays or not | ||||
|    */ | ||||
|   public SDVariable[] batchMmul(SDVariable[] inputsA, SDVariable[] inputsB, boolean transposeA, | ||||
|                                 boolean transposeB) { | ||||
|       boolean transposeB) { | ||||
|     SDValidation.validateNumerical("batchMmul", "inputsA", inputsA); | ||||
|     Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); | ||||
|     SDValidation.validateNumerical("batchMmul", "inputsB", inputsB); | ||||
| @ -331,7 +325,7 @@ public class SDBaseOps { | ||||
|    * @param transposeB Whether to transpose B arrays or not | ||||
|    */ | ||||
|   public SDVariable[] batchMmul(String[] names, SDVariable[] inputsA, SDVariable[] inputsB, | ||||
|                                 boolean transposeA, boolean transposeB) { | ||||
|       boolean transposeA, boolean transposeB) { | ||||
|     SDValidation.validateNumerical("batchMmul", "inputsA", inputsA); | ||||
|     Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); | ||||
|     SDValidation.validateNumerical("batchMmul", "inputsB", inputsB); | ||||
| @ -482,7 +476,7 @@ public class SDBaseOps { | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse, | ||||
|                             int... axis) { | ||||
|       int... axis) { | ||||
|     SDValidation.validateNumerical("cumprod", "in", in); | ||||
|     Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, exclusive, reverse, axis).outputVariable(); | ||||
| @ -563,7 +557,7 @@ public class SDBaseOps { | ||||
|    * @return output  (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse, | ||||
|                            int... axis) { | ||||
|       int... axis) { | ||||
|     SDValidation.validateNumerical("cumsum", "in", in); | ||||
|     Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, exclusive, reverse, axis).outputVariable(); | ||||
| @ -680,7 +674,7 @@ public class SDBaseOps { | ||||
|    * @param numPartitions Number of partitions, >= 1 | ||||
|    */ | ||||
|   public SDVariable[] dynamicPartition(String[] names, SDVariable x, SDVariable partitions, | ||||
|                                        int numPartitions) { | ||||
|       int numPartitions) { | ||||
|     SDValidation.validateNumerical("dynamicPartition", "x", x); | ||||
|     SDValidation.validateInteger("dynamicPartition", "partitions", partitions); | ||||
|     SDVariable[] out =  new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(sd,x, partitions, numPartitions).outputVariables(); | ||||
| @ -1189,7 +1183,7 @@ public class SDBaseOps { | ||||
|    * @return output INDArray  with linearly spaced elements (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable linspace(String name, DataType dataType, double start, double stop, | ||||
|                              long number) { | ||||
|       long number) { | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,dataType, start, stop, number).outputVariable(); | ||||
|     return sd.updateVariableNameAndReference(out, name); | ||||
|   } | ||||
| @ -1205,7 +1199,7 @@ public class SDBaseOps { | ||||
|    * @return output INDArray  with linearly spaced elements (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable linspace(SDVariable start, SDVariable stop, SDVariable number, | ||||
|                              DataType dataType) { | ||||
|       DataType dataType) { | ||||
|     SDValidation.validateNumerical("linspace", "start", start); | ||||
|     SDValidation.validateNumerical("linspace", "stop", stop); | ||||
|     SDValidation.validateInteger("linspace", "number", number); | ||||
| @ -1224,7 +1218,7 @@ public class SDBaseOps { | ||||
|    * @return output INDArray  with linearly spaced elements (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable linspace(String name, SDVariable start, SDVariable stop, SDVariable number, | ||||
|                              DataType dataType) { | ||||
|       DataType dataType) { | ||||
|     SDValidation.validateNumerical("linspace", "start", start); | ||||
|     SDValidation.validateNumerical("linspace", "stop", stop); | ||||
|     SDValidation.validateInteger("linspace", "number", number); | ||||
| @ -1445,7 +1439,7 @@ public class SDBaseOps { | ||||
|    * @return output Number of elements that the condition is satisfied for (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable matchConditionCount(SDVariable in, Condition condition, boolean keepDim, | ||||
|                                         int... dimensions) { | ||||
|       int... dimensions) { | ||||
|     SDValidation.validateNumerical("matchConditionCount", "in", in); | ||||
|     Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); | ||||
|     return new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, keepDim, dimensions).outputVariable(); | ||||
| @ -1469,7 +1463,7 @@ public class SDBaseOps { | ||||
|    * @return output Number of elements that the condition is satisfied for (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable matchConditionCount(String name, SDVariable in, Condition condition, | ||||
|                                         boolean keepDim, int... dimensions) { | ||||
|       boolean keepDim, int... dimensions) { | ||||
|     SDValidation.validateNumerical("matchConditionCount", "in", in); | ||||
|     Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, keepDim, dimensions).outputVariable(); | ||||
| @ -1514,7 +1508,7 @@ public class SDBaseOps { | ||||
|    * @return output Number of elements that the condition is satisfied for (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable matchConditionCount(String name, SDVariable in, Condition condition, | ||||
|                                         int... dimensions) { | ||||
|       int... dimensions) { | ||||
|     SDValidation.validateNumerical("matchConditionCount", "in", in); | ||||
|     Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, false, dimensions).outputVariable(); | ||||
| @ -1895,7 +1889,7 @@ public class SDBaseOps { | ||||
|    * @return output  (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable mmul(SDVariable x, SDVariable y, boolean transposeX, boolean transposeY, | ||||
|                          boolean transposeZ) { | ||||
|       boolean transposeZ) { | ||||
|     SDValidation.validateNumerical("mmul", "x", x); | ||||
|     SDValidation.validateNumerical("mmul", "y", y); | ||||
|     return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable(); | ||||
| @ -1914,7 +1908,7 @@ public class SDBaseOps { | ||||
|    * @return output  (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable mmul(String name, SDVariable x, SDVariable y, boolean transposeX, | ||||
|                          boolean transposeY, boolean transposeZ) { | ||||
|       boolean transposeY, boolean transposeZ) { | ||||
|     SDValidation.validateNumerical("mmul", "x", x); | ||||
|     SDValidation.validateNumerical("mmul", "y", y); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable(); | ||||
| @ -2304,14 +2298,14 @@ public class SDBaseOps { | ||||
|    * | ||||
|    * @param indices Indices - value 0 to depth-1 (NUMERIC type) | ||||
|    * @param depth Number of classes | ||||
|    * @param axis | ||||
|    * @param on | ||||
|    * @param off | ||||
|    * @param axis  | ||||
|    * @param on  | ||||
|    * @param off  | ||||
|    * @param dataType Output data type | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off, | ||||
|                            DataType dataType) { | ||||
|       DataType dataType) { | ||||
|     SDValidation.validateNumerical("oneHot", "indices", indices); | ||||
|     return new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, dataType).outputVariable(); | ||||
|   } | ||||
| @ -2324,14 +2318,14 @@ public class SDBaseOps { | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param indices Indices - value 0 to depth-1 (NUMERIC type) | ||||
|    * @param depth Number of classes | ||||
|    * @param axis | ||||
|    * @param on | ||||
|    * @param off | ||||
|    * @param axis  | ||||
|    * @param on  | ||||
|    * @param off  | ||||
|    * @param dataType Output data type | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, | ||||
|                            double off, DataType dataType) { | ||||
|       double off, DataType dataType) { | ||||
|     SDValidation.validateNumerical("oneHot", "indices", indices); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, dataType).outputVariable(); | ||||
|     return sd.updateVariableNameAndReference(out, name); | ||||
| @ -2344,9 +2338,9 @@ public class SDBaseOps { | ||||
|    * | ||||
|    * @param indices Indices - value 0 to depth-1 (NUMERIC type) | ||||
|    * @param depth Number of classes | ||||
|    * @param axis | ||||
|    * @param on | ||||
|    * @param off | ||||
|    * @param axis  | ||||
|    * @param on  | ||||
|    * @param off  | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) { | ||||
| @ -2362,13 +2356,13 @@ public class SDBaseOps { | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param indices Indices - value 0 to depth-1 (NUMERIC type) | ||||
|    * @param depth Number of classes | ||||
|    * @param axis | ||||
|    * @param on | ||||
|    * @param off | ||||
|    * @param axis  | ||||
|    * @param on  | ||||
|    * @param off  | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, | ||||
|                            double off) { | ||||
|       double off) { | ||||
|     SDValidation.validateNumerical("oneHot", "indices", indices); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, DataType.FLOAT).outputVariable(); | ||||
|     return sd.updateVariableNameAndReference(out, name); | ||||
| @ -2436,7 +2430,7 @@ public class SDBaseOps { | ||||
|    * As per onesLike(String, SDVariable) but the output datatype may be specified<br> | ||||
|    * | ||||
|    * @param input  (NUMERIC type) | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output  (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable onesLike(SDVariable input, DataType dataType) { | ||||
| @ -2449,7 +2443,7 @@ public class SDBaseOps { | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param input  (NUMERIC type) | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output  (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable onesLike(String name, SDVariable input, DataType dataType) { | ||||
| @ -2612,7 +2606,7 @@ public class SDBaseOps { | ||||
|    * @param from Initial/smallest value | ||||
|    * @param to Largest value (exclusive) | ||||
|    * @param step Step size | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output INDArray  with the specified values (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable range(double from, double to, double step, DataType dataType) { | ||||
| @ -2628,7 +2622,7 @@ public class SDBaseOps { | ||||
|    * @param from Initial/smallest value | ||||
|    * @param to Largest value (exclusive) | ||||
|    * @param step Step size | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output INDArray  with the specified values (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable range(String name, double from, double to, double step, DataType dataType) { | ||||
| @ -2644,7 +2638,7 @@ public class SDBaseOps { | ||||
|    * @param from Initial/smallest value (NUMERIC type) | ||||
|    * @param to Largest value (exclusive) (NUMERIC type) | ||||
|    * @param step Step size (NUMERIC type) | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output INDArray  with the specified values (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable range(SDVariable from, SDVariable to, SDVariable step, DataType dataType) { | ||||
| @ -2663,11 +2657,11 @@ public class SDBaseOps { | ||||
|    * @param from Initial/smallest value (NUMERIC type) | ||||
|    * @param to Largest value (exclusive) (NUMERIC type) | ||||
|    * @param step Step size (NUMERIC type) | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output INDArray  with the specified values (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable range(String name, SDVariable from, SDVariable to, SDVariable step, | ||||
|                           DataType dataType) { | ||||
|       DataType dataType) { | ||||
|     SDValidation.validateNumerical("range", "from", from); | ||||
|     SDValidation.validateNumerical("range", "to", to); | ||||
|     SDValidation.validateNumerical("range", "step", step); | ||||
| @ -2727,7 +2721,7 @@ public class SDBaseOps { | ||||
|    * @return output New array with values replaced where condition is satisfied (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable replaceWhere(String name, SDVariable update, SDVariable from, | ||||
|                                  Condition condition) { | ||||
|       Condition condition) { | ||||
|     SDValidation.validateNumerical("replaceWhere", "update", update); | ||||
|     SDValidation.validateNumerical("replaceWhere", "from", from); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(sd,update, from, condition).outputVariable(); | ||||
| @ -2761,7 +2755,7 @@ public class SDBaseOps { | ||||
|    * @return output New array with values replaced where condition is satisfied (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable replaceWhere(String name, SDVariable update, double value, | ||||
|                                  Condition condition) { | ||||
|       Condition condition) { | ||||
|     SDValidation.validateNumerical("replaceWhere", "update", update); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(sd,update, value, condition).outputVariable(); | ||||
|     return sd.updateVariableNameAndReference(out, name); | ||||
| @ -2799,47 +2793,6 @@ public class SDBaseOps { | ||||
|     return sd.updateVariableNameAndReference(out, name); | ||||
|   } | ||||
| 
 | ||||
| 
 | ||||
|   /** | ||||
|    * Split the input in to a list of sub tensors | ||||
|    * @param input the input to split | ||||
|    * @param numSizeSplits the number of splits | ||||
|    * @param splitDim the dimension to split along | ||||
|    * @return the set of output variables | ||||
|    */ | ||||
|   public SDVariable[] split(SDVariable input,int numSizeSplits,int splitDim) { | ||||
|     SDValidation.validateNumerical("split",input); | ||||
|     SDVariable[] out =  new org.nd4j.linalg.api.ops.impl.shape.Split(sd,input,numSizeSplits,splitDim).outputVariables(); | ||||
|     return out; | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Split the input in to a list of sub tensors | ||||
|    * @param name the potential name of the input | ||||
|    * @param input the input to split | ||||
|    * @param numSizeSplits the number of splits | ||||
|    * @param splitDim the dimension to split along | ||||
|    * @return the set of output variables | ||||
|    */ | ||||
|   public SDVariable[] split(String name,SDVariable input,int numSizeSplits,int splitDim) { | ||||
|     SDValidation.validateNumerical("split",input); | ||||
|     SDVariable[] out =  new org.nd4j.linalg.api.ops.impl.shape.Split(sd,input,numSizeSplits,splitDim).outputVariables(); | ||||
|     SDVariable[] ret = new SDVariable[out.length]; | ||||
|     AtomicInteger index = new AtomicInteger(0); | ||||
|     Arrays.stream(out).forEach(output -> { | ||||
|       if(index.get() < 1) { | ||||
|         ret[index.get()] = sd.updateVariableNameAndReference(output,name); | ||||
|         index.incrementAndGet(); | ||||
|       } | ||||
|       else { | ||||
|         ret[index.get()] = sd.updateVariableNameAndReference(output,name + ":" + index.get()); | ||||
|         index.incrementAndGet(); | ||||
|       } | ||||
|     }); | ||||
| 
 | ||||
|     return ret; | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br> | ||||
|    * input, but with the specified shape.<br> | ||||
| @ -2930,7 +2883,7 @@ public class SDBaseOps { | ||||
|    * @return output Reversed sequences (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim, | ||||
|                                     int batchDim) { | ||||
|       int batchDim) { | ||||
|     SDValidation.validateNumerical("reverseSequence", "x", x); | ||||
|     SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); | ||||
|     return new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, seqDim, batchDim).outputVariable(); | ||||
| @ -2947,7 +2900,7 @@ public class SDBaseOps { | ||||
|    * @return output Reversed sequences (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim, | ||||
|                                     int batchDim) { | ||||
|       int batchDim) { | ||||
|     SDValidation.validateNumerical("reverseSequence", "x", x); | ||||
|     SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, seqDim, batchDim).outputVariable(); | ||||
| @ -3123,7 +3076,7 @@ public class SDBaseOps { | ||||
|    * @return output The updated variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices, | ||||
|                                SDVariable updates) { | ||||
|       SDVariable updates) { | ||||
|     SDValidation.validateNumerical("scatterAdd", "ref", ref); | ||||
|     SDValidation.validateNumerical("scatterAdd", "indices", indices); | ||||
|     SDValidation.validateNumerical("scatterAdd", "updates", updates); | ||||
| @ -3166,7 +3119,7 @@ public class SDBaseOps { | ||||
|    * @return output The updated variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices, | ||||
|                                SDVariable updates) { | ||||
|       SDVariable updates) { | ||||
|     SDValidation.validateNumerical("scatterDiv", "ref", ref); | ||||
|     SDValidation.validateNumerical("scatterDiv", "indices", indices); | ||||
|     SDValidation.validateNumerical("scatterDiv", "updates", updates); | ||||
| @ -3209,7 +3162,7 @@ public class SDBaseOps { | ||||
|    * @return output The updated variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable scatterMax(String name, SDVariable ref, SDVariable indices, | ||||
|                                SDVariable updates) { | ||||
|       SDVariable updates) { | ||||
|     SDValidation.validateNumerical("scatterMax", "ref", ref); | ||||
|     SDValidation.validateNumerical("scatterMax", "indices", indices); | ||||
|     SDValidation.validateNumerical("scatterMax", "updates", updates); | ||||
| @ -3252,7 +3205,7 @@ public class SDBaseOps { | ||||
|    * @return output The updated variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable scatterMin(String name, SDVariable ref, SDVariable indices, | ||||
|                                SDVariable updates) { | ||||
|       SDVariable updates) { | ||||
|     SDValidation.validateNumerical("scatterMin", "ref", ref); | ||||
|     SDValidation.validateNumerical("scatterMin", "indices", indices); | ||||
|     SDValidation.validateNumerical("scatterMin", "updates", updates); | ||||
| @ -3295,7 +3248,7 @@ public class SDBaseOps { | ||||
|    * @return output The updated variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices, | ||||
|                                SDVariable updates) { | ||||
|       SDVariable updates) { | ||||
|     SDValidation.validateNumerical("scatterMul", "ref", ref); | ||||
|     SDValidation.validateNumerical("scatterMul", "indices", indices); | ||||
|     SDValidation.validateNumerical("scatterMul", "updates", updates); | ||||
| @ -3338,7 +3291,7 @@ public class SDBaseOps { | ||||
|    * @return output The updated variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices, | ||||
|                                SDVariable updates) { | ||||
|       SDVariable updates) { | ||||
|     SDValidation.validateNumerical("scatterSub", "ref", ref); | ||||
|     SDValidation.validateNumerical("scatterSub", "indices", indices); | ||||
|     SDValidation.validateNumerical("scatterSub", "updates", updates); | ||||
| @ -3381,7 +3334,7 @@ public class SDBaseOps { | ||||
|    * @return output The updated variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable scatterUpdate(String name, SDVariable ref, SDVariable indices, | ||||
|                                   SDVariable updates) { | ||||
|       SDVariable updates) { | ||||
|     SDValidation.validateNumerical("scatterUpdate", "ref", ref); | ||||
|     SDValidation.validateNumerical("scatterUpdate", "indices", indices); | ||||
|     SDValidation.validateNumerical("scatterUpdate", "updates", updates); | ||||
| @ -3595,7 +3548,7 @@ public class SDBaseOps { | ||||
|    * | ||||
|    * @param lengths Lengths of the sequences (NUMERIC type) | ||||
|    * @param maxLen Maximum sequence length | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) { | ||||
| @ -3610,7 +3563,7 @@ public class SDBaseOps { | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param lengths Lengths of the sequences (NUMERIC type) | ||||
|    * @param maxLen Maximum sequence length | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen, DataType dataType) { | ||||
| @ -3625,7 +3578,7 @@ public class SDBaseOps { | ||||
|    * | ||||
|    * @param lengths Lengths of the sequences (NUMERIC type) | ||||
|    * @param maxLen Maximum sequence length (INT type) | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) { | ||||
| @ -3641,11 +3594,11 @@ public class SDBaseOps { | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param lengths Lengths of the sequences (NUMERIC type) | ||||
|    * @param maxLen Maximum sequence length (INT type) | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen, | ||||
|                                  DataType dataType) { | ||||
|       DataType dataType) { | ||||
|     SDValidation.validateNumerical("sequenceMask", "lengths", lengths); | ||||
|     SDValidation.validateInteger("sequenceMask", "maxLen", maxLen); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable(); | ||||
| @ -3656,7 +3609,7 @@ public class SDBaseOps { | ||||
|    * see sequenceMask(String, SDVariable, SDVariable, DataType)<br> | ||||
|    * | ||||
|    * @param lengths  (NUMERIC type) | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output  (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable sequenceMask(SDVariable lengths, DataType dataType) { | ||||
| @ -3669,7 +3622,7 @@ public class SDBaseOps { | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param lengths  (NUMERIC type) | ||||
|    * @param dataType | ||||
|    * @param dataType  | ||||
|    * @return output  (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable sequenceMask(String name, SDVariable lengths, DataType dataType) { | ||||
| @ -3857,7 +3810,7 @@ public class SDBaseOps { | ||||
|    * keepDims = false: [a,c]<br> | ||||
|    * | ||||
|    * @param x  (NUMERIC type) | ||||
|    * @param keepDims | ||||
|    * @param keepDims  | ||||
|    * @param dimensions  (Size: AtLeast(min=0)) | ||||
|    * @return output  (NUMERIC type) | ||||
|    */ | ||||
| @ -3879,7 +3832,7 @@ public class SDBaseOps { | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param x  (NUMERIC type) | ||||
|    * @param keepDims | ||||
|    * @param keepDims  | ||||
|    * @param dimensions  (Size: AtLeast(min=0)) | ||||
|    * @return output  (NUMERIC type) | ||||
|    */ | ||||
| @ -4015,7 +3968,7 @@ public class SDBaseOps { | ||||
|    * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, boolean keepDims, | ||||
|                                       int... dimensions) { | ||||
|       int... dimensions) { | ||||
|     SDValidation.validateNumerical("standardDeviation", "x", x); | ||||
|     Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); | ||||
|     return new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); | ||||
| @ -4039,7 +3992,7 @@ public class SDBaseOps { | ||||
|    * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, | ||||
|                                       boolean keepDims, int... dimensions) { | ||||
|       boolean keepDims, int... dimensions) { | ||||
|     SDValidation.validateNumerical("standardDeviation", "x", x); | ||||
|     Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); | ||||
| @ -4084,7 +4037,7 @@ public class SDBaseOps { | ||||
|    * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, | ||||
|                                       int... dimensions) { | ||||
|       int... dimensions) { | ||||
|     SDValidation.validateNumerical("standardDeviation", "x", x); | ||||
|     Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, false, dimensions).outputVariable(); | ||||
| @ -4113,7 +4066,7 @@ public class SDBaseOps { | ||||
|    * @return output A subset of the input array (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, | ||||
|                                  int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { | ||||
|       int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { | ||||
|     SDValidation.validateNumerical("stridedSlice", "in", in); | ||||
|     Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); | ||||
|     Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); | ||||
| @ -4144,8 +4097,8 @@ public class SDBaseOps { | ||||
|    * @return output A subset of the input array (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end, | ||||
|                                  long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, | ||||
|                                  int shrinkAxisMask) { | ||||
|       long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, | ||||
|       int shrinkAxisMask) { | ||||
|     SDValidation.validateNumerical("stridedSlice", "in", in); | ||||
|     Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); | ||||
|     Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); | ||||
| @ -4196,7 +4149,7 @@ public class SDBaseOps { | ||||
|    * @return output A subset of the input array (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end, | ||||
|                                  long... strides) { | ||||
|       long... strides) { | ||||
|     SDValidation.validateNumerical("stridedSlice", "in", in); | ||||
|     Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); | ||||
|     Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); | ||||
| @ -4330,7 +4283,7 @@ public class SDBaseOps { | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable tensorMmul(SDVariable x, SDVariable y, int[] dimensionsX, int[] dimensionsY, | ||||
|                                boolean transposeX, boolean transposeY, boolean transposeZ) { | ||||
|       boolean transposeX, boolean transposeY, boolean transposeZ) { | ||||
|     SDValidation.validateNumerical("tensorMmul", "x", x); | ||||
|     SDValidation.validateNumerical("tensorMmul", "y", y); | ||||
|     Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); | ||||
| @ -4352,7 +4305,7 @@ public class SDBaseOps { | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[] dimensionsX, | ||||
|                                int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) { | ||||
|       int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) { | ||||
|     SDValidation.validateNumerical("tensorMmul", "x", x); | ||||
|     SDValidation.validateNumerical("tensorMmul", "y", y); | ||||
|     Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); | ||||
| @ -4389,7 +4342,7 @@ public class SDBaseOps { | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[] dimensionsX, | ||||
|                                int... dimensionsY) { | ||||
|       int... dimensionsY) { | ||||
|     SDValidation.validateNumerical("tensorMmul", "x", x); | ||||
|     SDValidation.validateNumerical("tensorMmul", "y", y); | ||||
|     Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); | ||||
| @ -4522,7 +4475,7 @@ public class SDBaseOps { | ||||
|    * @return output Unsorted segment output (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable unsortedSegmentMax(String name, SDVariable data, SDVariable segmentIds, | ||||
|                                        int numSegments) { | ||||
|       int numSegments) { | ||||
|     SDValidation.validateNumerical("unsortedSegmentMax", "data", data); | ||||
|     SDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(sd,data, segmentIds, numSegments).outputVariable(); | ||||
| @ -4561,7 +4514,7 @@ public class SDBaseOps { | ||||
|    * @return output Unsorted segment output (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable unsortedSegmentMean(String name, SDVariable data, SDVariable segmentIds, | ||||
|                                         int numSegments) { | ||||
|       int numSegments) { | ||||
|     SDValidation.validateNumerical("unsortedSegmentMean", "data", data); | ||||
|     SDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(sd,data, segmentIds, numSegments).outputVariable(); | ||||
| @ -4600,7 +4553,7 @@ public class SDBaseOps { | ||||
|    * @return output Unsorted segment output (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable unsortedSegmentMin(String name, SDVariable data, SDVariable segmentIds, | ||||
|                                        int numSegments) { | ||||
|       int numSegments) { | ||||
|     SDValidation.validateNumerical("unsortedSegmentMin", "data", data); | ||||
|     SDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(sd,data, segmentIds, numSegments).outputVariable(); | ||||
| @ -4639,7 +4592,7 @@ public class SDBaseOps { | ||||
|    * @return output Unsorted segment output (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable unsortedSegmentProd(String name, SDVariable data, SDVariable segmentIds, | ||||
|                                         int numSegments) { | ||||
|       int numSegments) { | ||||
|     SDValidation.validateNumerical("unsortedSegmentProd", "data", data); | ||||
|     SDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(sd,data, segmentIds, numSegments).outputVariable(); | ||||
| @ -4676,7 +4629,7 @@ public class SDBaseOps { | ||||
|    * @return output Unsorted segment output (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable unsortedSegmentSqrtN(String name, SDVariable data, SDVariable segmentIds, | ||||
|                                          int numSegments) { | ||||
|       int numSegments) { | ||||
|     SDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); | ||||
|     SDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(sd,data, segmentIds, numSegments).outputVariable(); | ||||
| @ -4715,7 +4668,7 @@ public class SDBaseOps { | ||||
|    * @return output Unsorted segment output (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable unsortedSegmentSum(String name, SDVariable data, SDVariable segmentIds, | ||||
|                                        int numSegments) { | ||||
|       int numSegments) { | ||||
|     SDValidation.validateNumerical("unsortedSegmentSum", "data", data); | ||||
|     SDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(sd,data, segmentIds, numSegments).outputVariable(); | ||||
| @ -4771,7 +4724,7 @@ public class SDBaseOps { | ||||
|    * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable variance(SDVariable x, boolean biasCorrected, boolean keepDims, | ||||
|                              int... dimensions) { | ||||
|       int... dimensions) { | ||||
|     SDValidation.validateNumerical("variance", "x", x); | ||||
|     Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); | ||||
|     return new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); | ||||
| @ -4795,7 +4748,7 @@ public class SDBaseOps { | ||||
|    * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable variance(String name, SDVariable x, boolean biasCorrected, boolean keepDims, | ||||
|                              int... dimensions) { | ||||
|       int... dimensions) { | ||||
|     SDValidation.validateNumerical("variance", "x", x); | ||||
|     Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); | ||||
|  | ||||
| @ -1,22 +1,20 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff.ops; | ||||
| 
 | ||||
| @ -250,7 +248,7 @@ public class SDBitwise extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Bitwise left cyclical shift operation. Supports broadcasting.<br> | ||||
|    * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br> | ||||
|    * Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br> | ||||
|    * {@code leftShiftCyclic(01110000, 2) -> 11000001}<br> | ||||
|    * | ||||
|    * @param x Input to be bit shifted (INT type) | ||||
| @ -265,7 +263,7 @@ public class SDBitwise extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Bitwise left cyclical shift operation. Supports broadcasting.<br> | ||||
|    * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br> | ||||
|    * Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br> | ||||
|    * {@code leftShiftCyclic(01110000, 2) -> 11000001}<br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
| @ -348,7 +346,7 @@ public class SDBitwise extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Bitwise right cyclical shift operation. Supports broadcasting.<br> | ||||
|    * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br> | ||||
|    * Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br> | ||||
|    * {@code rightShiftCyclic(00001110, 2) -> 10000011}<br> | ||||
|    * | ||||
|    * @param x Input to be bit shifted (INT type) | ||||
| @ -363,7 +361,7 @@ public class SDBitwise extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Bitwise right cyclical shift operation. Supports broadcasting.<br> | ||||
|    * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br> | ||||
|    * Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br> | ||||
|    * {@code rightShiftCyclic(00001110, 2) -> 10000011}<br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff.ops; | ||||
| 
 | ||||
| import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; | ||||
| 
 | ||||
| import java.lang.String; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| @ -42,8 +42,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * 2D Convolution layer operation - average pooling 2d<br> | ||||
|    * | ||||
|    * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                         (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling2DConfig Configuration Object | ||||
|    * @return output Result after applying average pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -56,8 +55,7 @@ public class SDCNN extends SDOps { | ||||
|    * 2D Convolution layer operation - average pooling 2d<br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                         (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling2DConfig Configuration Object | ||||
|    * @return output Result after applying average pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -70,9 +68,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * 3D convolution layer operation - average pooling 3d <br> | ||||
|    * | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    *                         (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    *                         (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling3DConfig Configuration Object | ||||
|    * @return output after applying average pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -85,9 +81,7 @@ public class SDCNN extends SDOps { | ||||
|    * 3D convolution layer operation - average pooling 3d <br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    *                         (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    *                         (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling3DConfig Configuration Object | ||||
|    * @return output after applying average pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -302,9 +296,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * Convolution 3D operation with optional bias <br> | ||||
|    * | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    * (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights  Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) | ||||
|    * @param bias  Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) | ||||
|    * @param Conv3DConfig Configuration Object | ||||
| @ -322,9 +314,7 @@ public class SDCNN extends SDOps { | ||||
|    * Convolution 3D operation with optional bias <br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    * (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights  Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) | ||||
|    * @param bias  Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) | ||||
|    * @param Conv3DConfig Configuration Object | ||||
| @ -342,9 +332,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * Convolution 3D operation with optional bias <br> | ||||
|    * | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    * (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights  Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) | ||||
|    * @param Conv3DConfig Configuration Object | ||||
|    * @return output Conv3d output variable (NUMERIC type) | ||||
| @ -359,9 +347,7 @@ public class SDCNN extends SDOps { | ||||
|    * Convolution 3D operation with optional bias <br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    * (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights  Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) | ||||
|    * @param Conv3DConfig Configuration Object | ||||
|    * @return output Conv3d output variable (NUMERIC type) | ||||
| @ -377,8 +363,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * 2D deconvolution operation with optional bias<br> | ||||
|    * | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) | ||||
|    * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) | ||||
|    * @param DeConv2DConfig Configuration Object | ||||
| @ -396,8 +381,7 @@ public class SDCNN extends SDOps { | ||||
|    * 2D deconvolution operation with optional bias<br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) | ||||
|    * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) | ||||
|    * @param DeConv2DConfig Configuration Object | ||||
| @ -415,8 +399,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * 2D deconvolution operation with optional bias<br> | ||||
|    * | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) | ||||
|    * @param DeConv2DConfig Configuration Object | ||||
|    * @return output result of deconv2d op (NUMERIC type) | ||||
| @ -432,8 +415,7 @@ public class SDCNN extends SDOps { | ||||
|    * 2D deconvolution operation with optional bias<br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) | ||||
|    * @param DeConv2DConfig Configuration Object | ||||
|    * @return output result of deconv2d op (NUMERIC type) | ||||
| @ -519,8 +501,7 @@ public class SDCNN extends SDOps { | ||||
|    * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br> | ||||
|    * = [mb, 2, 4, 4]<br> | ||||
|    * | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format | ||||
|    *                    (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param blockSize Block size, in the height/width dimension | ||||
|    * @param dataFormat Data format: "NCHW" or "NHWC" | ||||
|    * @return output Output variable (NUMERIC type) | ||||
| @ -537,8 +518,7 @@ public class SDCNN extends SDOps { | ||||
|    * = [mb, 2, 4, 4]<br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format | ||||
|    *                    (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param blockSize Block size, in the height/width dimension | ||||
|    * @param dataFormat Data format: "NCHW" or "NHWC" | ||||
|    * @return output Output variable (NUMERIC type) | ||||
| @ -756,8 +736,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br> | ||||
|    * | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                         (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling2DConfig Configuration Object | ||||
|    */ | ||||
|   public SDVariable[] maxPoolWithArgmax(SDVariable input, Pooling2DConfig Pooling2DConfig) { | ||||
| @ -769,8 +748,7 @@ public class SDCNN extends SDOps { | ||||
|    * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br> | ||||
|    * | ||||
|    * @param names names May be null. Arrays of names for the output variables. | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                         (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling2DConfig Configuration Object | ||||
|    */ | ||||
|   public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable input, | ||||
| @ -783,8 +761,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * 2D Convolution layer operation - max pooling 2d <br> | ||||
|    * | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                         (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling2DConfig Configuration Object | ||||
|    * @return output Result after applying max pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -797,8 +774,7 @@ public class SDCNN extends SDOps { | ||||
|    * 2D Convolution layer operation - max pooling 2d <br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                         (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling2DConfig Configuration Object | ||||
|    * @return output Result after applying max pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -811,9 +787,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * 3D convolution layer operation - max pooling 3d operation.<br> | ||||
|    * | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    *                         (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    *                         (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling3DConfig Configuration Object | ||||
|    * @return output Result after applying max pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -826,9 +800,7 @@ public class SDCNN extends SDOps { | ||||
|    * 3D convolution layer operation - max pooling 3d operation.<br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    *                         (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    *                         (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling3DConfig Configuration Object | ||||
|    * @return output Result after applying max pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -841,8 +813,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * Separable 2D convolution operation with optional bias <br> | ||||
|    * | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                      (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) | ||||
|    * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) | ||||
|    * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type) | ||||
| @ -862,8 +833,7 @@ public class SDCNN extends SDOps { | ||||
|    * Separable 2D convolution operation with optional bias <br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                      (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) | ||||
|    * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) | ||||
|    * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type) | ||||
| @ -883,8 +853,7 @@ public class SDCNN extends SDOps { | ||||
|   /** | ||||
|    * Separable 2D convolution operation with optional bias <br> | ||||
|    * | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                      (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) | ||||
|    * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) | ||||
|    * @param Conv2DConfig Configuration Object | ||||
| @ -902,8 +871,7 @@ public class SDCNN extends SDOps { | ||||
|    * Separable 2D convolution operation with optional bias <br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                      (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) | ||||
|    * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) | ||||
|    * @param Conv2DConfig Configuration Object | ||||
| @ -964,8 +932,7 @@ public class SDCNN extends SDOps { | ||||
|    * Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br> | ||||
|    * = [mb, 2, 4, 4] <br> | ||||
|    * | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format | ||||
|    *                    (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param blockSize  Block size, in the height/width dimension | ||||
|    * @param dataFormat Data format: "NCHW" or "NHWC" | ||||
|    * @return output Output variable (NUMERIC type) | ||||
| @ -982,8 +949,7 @@ public class SDCNN extends SDOps { | ||||
|    * = [mb, 2, 4, 4] <br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format | ||||
|    *                    (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param blockSize  Block size, in the height/width dimension | ||||
|    * @param dataFormat Data format: "NCHW" or "NHWC" | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff.ops; | ||||
| 
 | ||||
| import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; | ||||
| 
 | ||||
| import java.lang.String; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff.ops; | ||||
| 
 | ||||
| import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; | ||||
| 
 | ||||
| import java.lang.String; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff.ops; | ||||
| 
 | ||||
| import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; | ||||
| 
 | ||||
| import java.lang.String; | ||||
| import org.nd4j.autodiff.loss.LossReduce; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| @ -36,7 +36,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output loss variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, | ||||
| @ -56,7 +56,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output loss variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions, | ||||
| @ -116,7 +116,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param dimension Dimension to perform the cosine distance over | ||||
|    * @return output Cosine distance loss  (NUMERIC type) | ||||
|    */ | ||||
| @ -141,7 +141,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param dimension Dimension to perform the cosine distance over | ||||
|    * @return output Cosine distance loss  (NUMERIC type) | ||||
|    */ | ||||
| @ -202,6 +202,49 @@ public class SDLoss extends SDOps { | ||||
|     return sd.updateVariableNameAndReference(out, name); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * CTC Loss: Connectionist Temporal Classification Loss. See:<br> | ||||
|    * https://dl.acm.org/citation.cfm?id=1143891<br> | ||||
|    * | ||||
|    * @param targetLabels Label array (NUMERIC type) | ||||
|    * @param logitInput Inputs (NUMERIC type) | ||||
|    * @param targetLabelLengths Length of the target label (NUMERIC type) | ||||
|    * @param logitInputLengths Length of the input (NUMERIC type) | ||||
|    * @return output Ctc loss  (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable ctcLoss(SDVariable targetLabels, SDVariable logitInput, | ||||
|       SDVariable targetLabelLengths, SDVariable logitInputLengths) { | ||||
|     SDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels); | ||||
|     SDValidation.validateNumerical("ctcLoss", "logitInput", logitInput); | ||||
|     SDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths); | ||||
|     SDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths); | ||||
|     SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(sd,targetLabels, logitInput, targetLabelLengths, logitInputLengths).outputVariable(); | ||||
|     out.markAsLoss(); | ||||
|     return out; | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * CTC Loss: Connectionist Temporal Classification Loss. See:<br> | ||||
|    * https://dl.acm.org/citation.cfm?id=1143891<br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param targetLabels Label array (NUMERIC type) | ||||
|    * @param logitInput Inputs (NUMERIC type) | ||||
|    * @param targetLabelLengths Length of the target label (NUMERIC type) | ||||
|    * @param logitInputLengths Length of the input (NUMERIC type) | ||||
|    * @return output Ctc loss  (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable ctcLoss(String name, SDVariable targetLabels, SDVariable logitInput, | ||||
|       SDVariable targetLabelLengths, SDVariable logitInputLengths) { | ||||
|     SDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels); | ||||
|     SDValidation.validateNumerical("ctcLoss", "logitInput", logitInput); | ||||
|     SDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths); | ||||
|     SDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths); | ||||
|     SDVariable out =  new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(sd,targetLabels, logitInput, targetLabelLengths, logitInputLengths).outputVariable(); | ||||
|     out.markAsLoss(); | ||||
|     return sd.updateVariableNameAndReference(out, name); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Hinge loss: a loss function used for training classifiers.<br> | ||||
|    * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br> | ||||
| @ -210,7 +253,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights, | ||||
| @ -232,7 +275,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions, | ||||
| @ -297,7 +340,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param delta Loss function delta value | ||||
|    * @return output Huber loss (NUMERIC type) | ||||
|    */ | ||||
| @ -324,7 +367,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param delta Loss function delta value | ||||
|    * @return output Huber loss (NUMERIC type) | ||||
|    */ | ||||
| @ -423,7 +466,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param epsilon epsilon | ||||
|    * @return output Log loss  (NUMERIC type) | ||||
|    */ | ||||
| @ -445,7 +488,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param epsilon epsilon | ||||
|    * @return output Log loss  (NUMERIC type) | ||||
|    */ | ||||
| @ -499,7 +542,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) | ||||
|    * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param full Boolean flag. true for logPoissonFull, false for logPoisson | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
| @ -521,7 +564,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) | ||||
|    * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param full Boolean flag. true for logPoissonFull, false for logPoisson | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
| @ -585,7 +628,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output Loss variable, scalar output (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions, | ||||
| @ -608,7 +651,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output Loss variable, scalar output (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions, | ||||
| @ -666,13 +709,13 @@ public class SDLoss extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br> | ||||
|    * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br> | ||||
|    * When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br> | ||||
|    * this is the mean squared error loss function.<br> | ||||
|    * | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, | ||||
| @ -687,14 +730,14 @@ public class SDLoss extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br> | ||||
|    * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br> | ||||
|    * When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br> | ||||
|    * this is the mean squared error loss function.<br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
|   public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions, | ||||
| @ -709,7 +752,7 @@ public class SDLoss extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br> | ||||
|    * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br> | ||||
|    * When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br> | ||||
|    * this is the mean squared error loss function.<br> | ||||
|    * | ||||
|    * @param label Label array (NUMERIC type) | ||||
| @ -728,7 +771,7 @@ public class SDLoss extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br> | ||||
|    * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br> | ||||
|    * When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br> | ||||
|    * this is the mean squared error loss function.<br> | ||||
|    * | ||||
|    * @param name name May be null. Name for the output variable | ||||
| @ -764,7 +807,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictionLogits Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param labelSmoothing Label smoothing value. Default value: 0 | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
| @ -796,7 +839,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictionLogits Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param labelSmoothing Label smoothing value. Default value: 0 | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
| @ -872,7 +915,7 @@ public class SDLoss extends SDOps { | ||||
|   /** | ||||
|    * Applies the softmax activation function to the input, then implement multi-class cross entropy:<br> | ||||
|    * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br> | ||||
|    * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * otherwise, the output is a scalar.<br> | ||||
|    * <p><br> | ||||
|    * When label smoothing is > 0, the following label smoothing is used:<br> | ||||
| @ -884,7 +927,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) | ||||
|    * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param labelSmoothing Label smoothing value. Default value: 0 | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
| @ -901,7 +944,7 @@ public class SDLoss extends SDOps { | ||||
|   /** | ||||
|    * Applies the softmax activation function to the input, then implement multi-class cross entropy:<br> | ||||
|    * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br> | ||||
|    * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * otherwise, the output is a scalar.<br> | ||||
|    * <p><br> | ||||
|    * When label smoothing is > 0, the following label smoothing is used:<br> | ||||
| @ -914,7 +957,7 @@ public class SDLoss extends SDOps { | ||||
|    * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) | ||||
|    * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param labelSmoothing Label smoothing value. Default value: 0 | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
| @ -932,7 +975,7 @@ public class SDLoss extends SDOps { | ||||
|   /** | ||||
|    * Applies the softmax activation function to the input, then implement multi-class cross entropy:<br> | ||||
|    * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br> | ||||
|    * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * otherwise, the output is a scalar.<br> | ||||
|    * <p><br> | ||||
|    * When label smoothing is > 0, the following label smoothing is used:<br> | ||||
| @ -959,7 +1002,7 @@ public class SDLoss extends SDOps { | ||||
|   /** | ||||
|    * Applies the softmax activation function to the input, then implement multi-class cross entropy:<br> | ||||
|    * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br> | ||||
|    * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * otherwise, the output is a scalar.<br> | ||||
|    * <p><br> | ||||
|    * When label smoothing is > 0, the following label smoothing is used:<br> | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff.ops; | ||||
| 
 | ||||
| import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; | ||||
| 
 | ||||
| import java.lang.String; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff.ops; | ||||
| 
 | ||||
| import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; | ||||
| 
 | ||||
| import java.lang.String; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff.ops; | ||||
| 
 | ||||
| import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; | ||||
| 
 | ||||
| import java.lang.String; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| @ -144,22 +144,22 @@ public class SDRNN extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Long Short-Term Memory layer - Hochreiter 1997.<br> | ||||
|    * SUPPORTS following data formats:\n<br> | ||||
|    * for unidirectional: \n" +<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]\n<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]\n<br> | ||||
|    * SUPPORTS following data formats:<br> | ||||
|    * for unidirectional:<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]<br> | ||||
|    * NTS: shapes [numExamples, timeLength, inOutSize]<br> | ||||
|    * for bidirectional:\n<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br> | ||||
|    * SUPPORTS following direction modes:\n<br> | ||||
|    * for bidirectional:<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br> | ||||
|    * SUPPORTS following direction modes:<br> | ||||
|    * FWD: forward<br> | ||||
|    * BWD: backward<br> | ||||
|    * BIDIR_SUM: bidirectional sum\n<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat\n" +<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br> | ||||
|    * BIDIR_SUM: bidirectional sum<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br> | ||||
|    * You may use different gate configurations:<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br> | ||||
|    * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br> | ||||
|    * | ||||
|    * @param x  Input, with shape dependent on the data format (in config). (NUMERIC type) | ||||
| @ -180,22 +180,22 @@ public class SDRNN extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Long Short-Term Memory layer - Hochreiter 1997.<br> | ||||
|    * SUPPORTS following data formats:\n<br> | ||||
|    * for unidirectional: \n" +<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]\n<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]\n<br> | ||||
|    * SUPPORTS following data formats:<br> | ||||
|    * for unidirectional:<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]<br> | ||||
|    * NTS: shapes [numExamples, timeLength, inOutSize]<br> | ||||
|    * for bidirectional:\n<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br> | ||||
|    * SUPPORTS following direction modes:\n<br> | ||||
|    * for bidirectional:<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br> | ||||
|    * SUPPORTS following direction modes:<br> | ||||
|    * FWD: forward<br> | ||||
|    * BWD: backward<br> | ||||
|    * BIDIR_SUM: bidirectional sum\n<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat\n" +<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br> | ||||
|    * BIDIR_SUM: bidirectional sum<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br> | ||||
|    * You may use different gate configurations:<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br> | ||||
|    * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br> | ||||
|    * | ||||
|    * @param names names May be null. Arrays of names for the output variables. | ||||
| @ -218,22 +218,22 @@ public class SDRNN extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Long Short-Term Memory layer - Hochreiter 1997.<br> | ||||
|    * SUPPORTS following data formats:\n<br> | ||||
|    * for unidirectional: \n" +<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]\n<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]\n<br> | ||||
|    * SUPPORTS following data formats:<br> | ||||
|    * for unidirectional:<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]<br> | ||||
|    * NTS: shapes [numExamples, timeLength, inOutSize]<br> | ||||
|    * for bidirectional:\n<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br> | ||||
|    * SUPPORTS following direction modes:\n<br> | ||||
|    * for bidirectional:<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br> | ||||
|    * SUPPORTS following direction modes:<br> | ||||
|    * FWD: forward<br> | ||||
|    * BWD: backward<br> | ||||
|    * BIDIR_SUM: bidirectional sum\n<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat\n" +<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br> | ||||
|    * BIDIR_SUM: bidirectional sum<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br> | ||||
|    * You may use different gate configurations:<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br> | ||||
|    * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br> | ||||
|    * | ||||
|    * @param x  Input, with shape dependent on the data format (in config). (NUMERIC type) | ||||
| @ -248,22 +248,22 @@ public class SDRNN extends SDOps { | ||||
| 
 | ||||
|   /** | ||||
|    * Long Short-Term Memory layer - Hochreiter 1997.<br> | ||||
|    * SUPPORTS following data formats:\n<br> | ||||
|    * for unidirectional: \n" +<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]\n<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]\n<br> | ||||
|    * SUPPORTS following data formats:<br> | ||||
|    * for unidirectional:<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]<br> | ||||
|    * NTS: shapes [numExamples, timeLength, inOutSize]<br> | ||||
|    * for bidirectional:\n<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br> | ||||
|    * SUPPORTS following direction modes:\n<br> | ||||
|    * for bidirectional:<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br> | ||||
|    * SUPPORTS following direction modes:<br> | ||||
|    * FWD: forward<br> | ||||
|    * BWD: backward<br> | ||||
|    * BIDIR_SUM: bidirectional sum\n<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat\n" +<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br> | ||||
|    * BIDIR_SUM: bidirectional sum<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br> | ||||
|    * You may use different gate configurations:<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br> | ||||
|    * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br> | ||||
|    * | ||||
|    * @param names names May be null. Arrays of names for the output variables. | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff.ops; | ||||
| 
 | ||||
| import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; | ||||
| 
 | ||||
| import java.lang.String; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.enums; | ||||
| 
 | ||||
| /** | ||||
|  * Activations */ | ||||
| public enum CellAct { | ||||
|   TANH, | ||||
| 
 | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.enums; | ||||
| 
 | ||||
| /** | ||||
|  * Data format: "NCHW" or "NHWC" */ | ||||
| public enum DataFormat { | ||||
|   NCHW, | ||||
| 
 | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.enums; | ||||
| 
 | ||||
| /** | ||||
|  * Activations */ | ||||
| public enum GateAct { | ||||
|   TANH, | ||||
| 
 | ||||
|  | ||||
| @ -1,32 +1,43 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.enums; | ||||
| 
 | ||||
| /** | ||||
|  * ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. | ||||
|  * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. | ||||
|  * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. | ||||
|  * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. | ||||
|  * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. | ||||
|  * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. | ||||
|  * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. */ | ||||
| public enum ImageResizeMethod { | ||||
|   ResizeBilinear, // as java require | ||||
|   ResizeNearest, | ||||
|   ResizeBilinear, | ||||
| 
 | ||||
|   ResizeBicubic, | ||||
|   ResizeArea, | ||||
| 
 | ||||
|   ResizeNearest, | ||||
| 
 | ||||
|   ResizeGaussian, | ||||
|   ResizeLanczos3, | ||||
| 
 | ||||
|   ResizeLanczos5, | ||||
|   ResizeMitchellcubic; | ||||
| 
 | ||||
|   ResizeMitchelcubic, | ||||
| 
 | ||||
|   ResizeArea | ||||
| } | ||||
|  | ||||
| @ -1,25 +1,28 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.enums; | ||||
| 
 | ||||
| /** | ||||
|  * for unidirectional:  TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br> | ||||
|  *   NST: shape [numExamples, inOutSize, timeLength]<br> | ||||
|  *   NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout<br> for bidirectional: | ||||
|  *    T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) */ | ||||
| public enum LSTMDataFormat { | ||||
|   TNS, | ||||
| 
 | ||||
|  | ||||
| @ -1,25 +1,30 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.enums; | ||||
| 
 | ||||
| /** | ||||
|  * direction <br> | ||||
|  *  FWD: 0 = fwd | ||||
|  *  BWD: 1 = bwd | ||||
|  *  BIDIR_SUM: 2 = bidirectional sum | ||||
|  *  BIDIR_CONCAT: 3 = bidirectional concat | ||||
|  *  BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */ | ||||
| public enum LSTMDirectionMode { | ||||
|   FWD, | ||||
| 
 | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.enums; | ||||
| 
 | ||||
| /** | ||||
|  * Activations */ | ||||
| public enum OutAct { | ||||
|   TANH, | ||||
| 
 | ||||
|  | ||||
| @ -1,22 +1,20 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.enums; | ||||
| 
 | ||||
|  | ||||
| @ -1,22 +1,20 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.enums; | ||||
| 
 | ||||
|  | ||||
| @ -1,25 +1,28 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.enums; | ||||
| 
 | ||||
| /** | ||||
|  *  The data format of the input. Input shape depends on data format (in config):<br> | ||||
|  *  TNS -> [timeSteps, batchSize, inSize]<br> | ||||
|  *  NST -> [batchSize, inSize, timeSteps]<br> | ||||
|  *  NTS -> [batchSize, timeSteps, inSize]<br> */ | ||||
| public enum RnnDataFormat { | ||||
|   TNS, | ||||
| 
 | ||||
|  | ||||
| @ -50,7 +50,7 @@ public abstract class BaseLoss extends DynamicCustomOp { | ||||
|         addArgs(); | ||||
|     } | ||||
| 
 | ||||
|     protected static INDArray getWeights(INDArray weights, INDArray predictions){ | ||||
|     protected static INDArray getWeights(INDArray weights, INDArray predictions) { | ||||
|         return (weights != null) ? weights : Nd4j.scalar(predictions.dataType(), 1.0); | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -20,29 +20,20 @@ | ||||
| 
 | ||||
| package org.nd4j.linalg.api.ops.impl.loss; | ||||
| 
 | ||||
| import org.nd4j.autodiff.loss.LossReduce; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp; | ||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||
| import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class CtcLoss extends BaseLoss { | ||||
| public class CtcLoss extends DynamicCustomOp { | ||||
| 
 | ||||
| 
 | ||||
|     public CtcLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){ | ||||
|         super(sameDiff, lossReduce, predictions, weights, labels); | ||||
|     public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){ | ||||
|         super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths}); | ||||
|     } | ||||
| 
 | ||||
|     public CtcLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights, | ||||
|                    LossReduce lossReduce) { | ||||
|         this(sameDiff, lossReduce, predictions, weights, label); | ||||
|     } | ||||
| 
 | ||||
|     public CtcLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ | ||||
|         super(lossReduce, predictions, weights, labels); | ||||
|     } | ||||
| 
 | ||||
|     public CtcLoss(){ } | ||||
| 
 | ||||
| @ -52,9 +43,9 @@ public class CtcLoss extends BaseLoss { | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<SDVariable> doDiff(List<SDVariable> grad){ | ||||
|     public List<SDVariable> doDiff(List<SDVariable> grad) { | ||||
|         //No external gradient | ||||
|         //Args are: predictions, weights, label | ||||
|         return new AbsoluteDifferenceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); | ||||
|         return new CtcLossBp(sameDiff,  arg(0), arg(1), arg(2),arg(3)).outputs(); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -20,17 +20,17 @@ | ||||
| 
 | ||||
| package org.nd4j.linalg.api.ops.impl.loss.bp; | ||||
| 
 | ||||
| import org.nd4j.autodiff.loss.LossReduce; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class CtcLossBp extends BaseLossBp { | ||||
| public class CtcLossBp extends DynamicCustomOp { | ||||
| 
 | ||||
| 
 | ||||
|     public CtcLossBp(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){ | ||||
|         super(sameDiff, lossReduce, predictions, weights, labels); | ||||
|     public CtcLossBp(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){ | ||||
|         super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths}); | ||||
|     } | ||||
| 
 | ||||
|     public CtcLossBp(){ } | ||||
|  | ||||
| @ -1,22 +1,20 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.linalg.factory.ops; | ||||
| 
 | ||||
|  | ||||
| @ -1,22 +1,20 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.linalg.factory.ops; | ||||
| 
 | ||||
| @ -134,7 +132,7 @@ public class NDBitwise { | ||||
| 
 | ||||
|   /** | ||||
|    * Bitwise left cyclical shift operation. Supports broadcasting.<br> | ||||
|    * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br> | ||||
|    * Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br> | ||||
|    * {@code leftShiftCyclic(01110000, 2) -> 11000001}<br> | ||||
|    * | ||||
|    * @param x Input to be bit shifted (INT type) | ||||
| @ -180,7 +178,7 @@ public class NDBitwise { | ||||
| 
 | ||||
|   /** | ||||
|    * Bitwise right cyclical shift operation. Supports broadcasting.<br> | ||||
|    * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br> | ||||
|    * Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br> | ||||
|    * {@code rightShiftCyclic(00001110, 2) -> 10000011}<br> | ||||
|    * | ||||
|    * @param x Input to be bit shifted (INT type) | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.linalg.factory.ops; | ||||
| 
 | ||||
| import static org.nd4j.linalg.factory.NDValidation.isSameType; | ||||
| 
 | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| import org.nd4j.enums.DataFormat; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| @ -41,8 +41,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * 2D Convolution layer operation - average pooling 2d<br> | ||||
|    * | ||||
|    * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                         (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling2DConfig Configuration Object | ||||
|    * @return output Result after applying average pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -54,9 +53,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * 3D convolution layer operation - average pooling 3d <br> | ||||
|    * | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    *                         (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    *                         (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling3DConfig Configuration Object | ||||
|    * @return output after applying average pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -161,9 +158,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * Convolution 3D operation with optional bias <br> | ||||
|    * | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    * (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights  Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) | ||||
|    * @param bias  Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) | ||||
|    * @param Conv3DConfig Configuration Object | ||||
| @ -180,9 +175,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * Convolution 3D operation with optional bias <br> | ||||
|    * | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    * (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights  Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) | ||||
|    * @param Conv3DConfig Configuration Object | ||||
|    * @return output Conv3d output variable (NUMERIC type) | ||||
| @ -196,8 +189,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * 2D deconvolution operation with optional bias<br> | ||||
|    * | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) | ||||
|    * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) | ||||
|    * @param DeConv2DConfig Configuration Object | ||||
| @ -214,8 +206,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * 2D deconvolution operation with optional bias<br> | ||||
|    * | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) | ||||
|    * @param DeConv2DConfig Configuration Object | ||||
|    * @return output result of deconv2d op (NUMERIC type) | ||||
| @ -263,8 +254,7 @@ public class NDCNN { | ||||
|    * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br> | ||||
|    * = [mb, 2, 4, 4]<br> | ||||
|    * | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format | ||||
|    *                    (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param blockSize Block size, in the height/width dimension | ||||
|    * @param dataFormat Data format: "NCHW" or "NHWC" | ||||
|    * @return output Output variable (NUMERIC type) | ||||
| @ -373,8 +363,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br> | ||||
|    * | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                         (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling2DConfig Configuration Object | ||||
|    */ | ||||
|   public INDArray[] maxPoolWithArgmax(INDArray input, Pooling2DConfig Pooling2DConfig) { | ||||
| @ -385,8 +374,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * 2D Convolution layer operation - max pooling 2d <br> | ||||
|    * | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                         (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling2DConfig Configuration Object | ||||
|    * @return output Result after applying max pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -398,9 +386,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * 3D convolution layer operation - max pooling 3d operation.<br> | ||||
|    * | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format | ||||
|    *                         (shape [minibatch, channels, depth, height, width]) or NDHWC format | ||||
|    *                         (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type) | ||||
|    * @param Pooling3DConfig Configuration Object | ||||
|    * @return output Result after applying max pooling on the input (NUMERIC type) | ||||
|    */ | ||||
| @ -412,8 +398,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * Separable 2D convolution operation with optional bias <br> | ||||
|    * | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                      (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) | ||||
|    * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) | ||||
|    * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type) | ||||
| @ -432,8 +417,7 @@ public class NDCNN { | ||||
|   /** | ||||
|    * Separable 2D convolution operation with optional bias <br> | ||||
|    * | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format | ||||
|    *                      (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) | ||||
|    * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) | ||||
|    * @param Conv2DConfig Configuration Object | ||||
| @ -471,8 +455,7 @@ public class NDCNN { | ||||
|    * Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br> | ||||
|    * = [mb, 2, 4, 4] <br> | ||||
|    * | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format | ||||
|    *                    (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) | ||||
|    * @param blockSize  Block size, in the height/width dimension | ||||
|    * @param dataFormat Data format: "NCHW" or "NHWC" | ||||
|    * @return output Output variable (NUMERIC type) | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.linalg.factory.ops; | ||||
| 
 | ||||
| import static org.nd4j.linalg.factory.NDValidation.isSameType; | ||||
| 
 | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| import org.nd4j.enums.ImageResizeMethod; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.linalg.factory.ops; | ||||
| 
 | ||||
| import static org.nd4j.linalg.factory.NDValidation.isSameType; | ||||
| 
 | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.NDValidation; | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.linalg.factory.ops; | ||||
| 
 | ||||
| import static org.nd4j.linalg.factory.NDValidation.isSameType; | ||||
| 
 | ||||
| import org.nd4j.autodiff.loss.LossReduce; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.NDValidation; | ||||
| @ -35,7 +35,7 @@ public class NDLoss { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output loss variable (NUMERIC type) | ||||
|    */ | ||||
|   public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights, | ||||
| @ -71,7 +71,7 @@ public class NDLoss { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param dimension Dimension to perform the cosine distance over | ||||
|    * @return output Cosine distance loss  (NUMERIC type) | ||||
|    */ | ||||
| @ -104,6 +104,25 @@ public class NDLoss { | ||||
|     return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension))[0]; | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * CTC Loss: Connectionist Temporal Classification Loss. See:<br> | ||||
|    * https://dl.acm.org/citation.cfm?id=1143891<br> | ||||
|    * | ||||
|    * @param targetLabels Label array (NUMERIC type) | ||||
|    * @param logitInput Inputs (NUMERIC type) | ||||
|    * @param targetLabelLengths Length of the target label (NUMERIC type) | ||||
|    * @param logitInputLengths Length of the input (NUMERIC type) | ||||
|    * @return output Ctc loss  (NUMERIC type) | ||||
|    */ | ||||
|   public INDArray ctcLoss(INDArray targetLabels, INDArray logitInput, INDArray targetLabelLengths, | ||||
|       INDArray logitInputLengths) { | ||||
|     NDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels); | ||||
|     NDValidation.validateNumerical("ctcLoss", "logitInput", logitInput); | ||||
|     NDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths); | ||||
|     NDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths); | ||||
|     return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(targetLabels, logitInput, targetLabelLengths, logitInputLengths))[0]; | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Hinge loss: a loss function used for training classifiers.<br> | ||||
|    * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br> | ||||
| @ -112,7 +131,7 @@ public class NDLoss { | ||||
|    * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
|   public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights, | ||||
| @ -152,7 +171,7 @@ public class NDLoss { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param delta Loss function delta value | ||||
|    * @return output Huber loss (NUMERIC type) | ||||
|    */ | ||||
| @ -204,7 +223,7 @@ public class NDLoss { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param epsilon epsilon | ||||
|    * @return output Log loss  (NUMERIC type) | ||||
|    */ | ||||
| @ -237,7 +256,7 @@ public class NDLoss { | ||||
|    * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) | ||||
|    * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param full Boolean flag. true for logPoissonFull, false for logPoisson | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
| @ -275,7 +294,7 @@ public class NDLoss { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output Loss variable, scalar output (NUMERIC type) | ||||
|    */ | ||||
|   public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights, | ||||
| @ -306,13 +325,13 @@ public class NDLoss { | ||||
| 
 | ||||
|   /** | ||||
|    * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br> | ||||
|    * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br> | ||||
|    * When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br> | ||||
|    * this is the mean squared error loss function.<br> | ||||
|    * | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictions Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
|   public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights, | ||||
| @ -325,7 +344,7 @@ public class NDLoss { | ||||
| 
 | ||||
|   /** | ||||
|    * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br> | ||||
|    * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br> | ||||
|    * When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br> | ||||
|    * this is the mean squared error loss function.<br> | ||||
|    * | ||||
|    * @param label Label array (NUMERIC type) | ||||
| @ -357,7 +376,7 @@ public class NDLoss { | ||||
|    * @param label Label array (NUMERIC type) | ||||
|    * @param predictionLogits Predictions array (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param labelSmoothing Label smoothing value. Default value: 0 | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
| @ -398,7 +417,7 @@ public class NDLoss { | ||||
|   /** | ||||
|    * Applies the softmax activation function to the input, then implement multi-class cross entropy:<br> | ||||
|    * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br> | ||||
|    * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * otherwise, the output is a scalar.<br> | ||||
|    * <p><br> | ||||
|    * When label smoothing is > 0, the following label smoothing is used:<br> | ||||
| @ -410,7 +429,7 @@ public class NDLoss { | ||||
|    * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) | ||||
|    * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) | ||||
|    * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) | ||||
|    * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} | ||||
|    * @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT | ||||
|    * @param labelSmoothing Label smoothing value. Default value: 0 | ||||
|    * @return output Loss variable (NUMERIC type) | ||||
|    */ | ||||
| @ -425,7 +444,7 @@ public class NDLoss { | ||||
|   /** | ||||
|    * Applies the softmax activation function to the input, then implement multi-class cross entropy:<br> | ||||
|    * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br> | ||||
|    * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br> | ||||
|    * otherwise, the output is a scalar.<br> | ||||
|    * <p><br> | ||||
|    * When label smoothing is > 0, the following label smoothing is used:<br> | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.linalg.factory.ops; | ||||
| 
 | ||||
| import static org.nd4j.linalg.factory.NDValidation.isSameType; | ||||
| 
 | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| import org.nd4j.enums.PartitionMode; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.linalg.factory.ops; | ||||
| 
 | ||||
| import static org.nd4j.linalg.factory.NDValidation.isSameType; | ||||
| 
 | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| import org.nd4j.enums.PadMode; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.linalg.factory.ops; | ||||
| 
 | ||||
| import static org.nd4j.linalg.factory.NDValidation.isSameType; | ||||
| 
 | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; | ||||
| import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; | ||||
| @ -85,22 +85,22 @@ public class NDRNN { | ||||
| 
 | ||||
|   /** | ||||
|    * Long Short-Term Memory layer - Hochreiter 1997.<br> | ||||
|    * SUPPORTS following data formats:\n<br> | ||||
|    * for unidirectional: \n" +<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]\n<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]\n<br> | ||||
|    * SUPPORTS following data formats:<br> | ||||
|    * for unidirectional:<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]<br> | ||||
|    * NTS: shapes [numExamples, timeLength, inOutSize]<br> | ||||
|    * for bidirectional:\n<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br> | ||||
|    * SUPPORTS following direction modes:\n<br> | ||||
|    * for bidirectional:<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br> | ||||
|    * SUPPORTS following direction modes:<br> | ||||
|    * FWD: forward<br> | ||||
|    * BWD: backward<br> | ||||
|    * BIDIR_SUM: bidirectional sum\n<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat\n" +<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br> | ||||
|    * BIDIR_SUM: bidirectional sum<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br> | ||||
|    * You may use different gate configurations:<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br> | ||||
|    * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br> | ||||
|    * | ||||
|    * @param x  Input, with shape dependent on the data format (in config). (NUMERIC type) | ||||
| @ -121,22 +121,22 @@ public class NDRNN { | ||||
| 
 | ||||
|   /** | ||||
|    * Long Short-Term Memory layer - Hochreiter 1997.<br> | ||||
|    * SUPPORTS following data formats:\n<br> | ||||
|    * for unidirectional: \n" +<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]\n<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]\n<br> | ||||
|    * SUPPORTS following data formats:<br> | ||||
|    * for unidirectional:<br> | ||||
|    * TNS: shapes [timeLength, numExamples, inOutSize]<br> | ||||
|    * NST: shapes [numExamples, inOutSize, timeLength]<br> | ||||
|    * NTS: shapes [numExamples, timeLength, inOutSize]<br> | ||||
|    * for bidirectional:\n<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br> | ||||
|    * SUPPORTS following direction modes:\n<br> | ||||
|    * for bidirectional:<br> | ||||
|    * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br> | ||||
|    * SUPPORTS following direction modes:<br> | ||||
|    * FWD: forward<br> | ||||
|    * BWD: backward<br> | ||||
|    * BIDIR_SUM: bidirectional sum\n<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat\n" +<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br> | ||||
|    * BIDIR_SUM: bidirectional sum<br> | ||||
|    * BIDIR_CONCAT: bidirectional concat<br> | ||||
|    * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br> | ||||
|    * You may use different gate configurations:<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br> | ||||
|    * specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br> | ||||
|    * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br> | ||||
|    * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br> | ||||
|    * | ||||
|    * @param x  Input, with shape dependent on the data format (in config). (NUMERIC type) | ||||
|  | ||||
| @ -1,25 +1,25 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== | ||||
| 
 | ||||
| package org.nd4j.linalg.factory.ops; | ||||
| 
 | ||||
| import static org.nd4j.linalg.factory.NDValidation.isSameType; | ||||
| 
 | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user