commit
0c81654567
|
@ -0,0 +1,614 @@
|
||||||
|
# ND4J Op Definitions and Code Generation
|
||||||
|
This project contains the ND4J Op definitions, the DSL (Domain Specific Language) that is used for those definitions and
|
||||||
|
code generators that use those definitions to create the actual Java code that is used to use the defined operations.
|
||||||
|
|
||||||
|
|
||||||
|
## Why define ops externally?
|
||||||
|
As we started to support SameDiff, we also started to introduce inconsistencies between SameDiff and ND4J. Even though
|
||||||
|
both of those libraries use the same underlying implementations for operations, there are both small and large
|
||||||
|
differences in the API that we provide for them. Sometimes, we have provided an official API only for one usage, and not
|
||||||
|
the other. And very often the documentation for a single op is in many different places.
|
||||||
|
|
||||||
|
In the future we want to support other programming languages with libnd4j, and provide more ways to use our C++ backend.
|
||||||
|
This would only increase the aforementioned problems.
|
||||||
|
|
||||||
|
The root of all of those problems, is that Ops are used across different environments, and there is no single way of
|
||||||
|
defining them with an enforced interface.
|
||||||
|
|
||||||
|
|
||||||
|
## How does this project help with enforcing a single consistent interface for ops?
|
||||||
|
The solution we propose, is to define the operations separately, and then generate the necessary API code for them. All
|
||||||
|
of the generated code is to be considered untouchable, editing it will result in the changes being overwritten sooner
|
||||||
|
rather than later.
|
||||||
|
|
||||||
|
The combination of external op definition and code generation, opens up many opportunities for us. The first one being
|
||||||
|
that we can easily create consistent APIs for both ND4J and SameDiff in Java. But, looking into the future, we can also
|
||||||
|
create those APIs for other programming languages like Python, Swift, or even C#. We can even go beyond programming
|
||||||
|
languages, and use the op definitions to create better documentation than what JavaDoc or similar might support out of
|
||||||
|
the box.
|
||||||
|
|
||||||
|
## Maintenance
|
||||||
|
This project is currently maintained by Paul Dubs, with feedback often collected from raver119 and Alex Black.
|
||||||
|
|
||||||
|
## Current Status
|
||||||
|
At the moment we still focus on nailing down an easily readable and contribution friendly DSL for op definition and code
|
||||||
|
generation that can replace namespace definitions. This means that at the moment we still rely on the pre-existing Op
|
||||||
|
definition classes that already exist in ND4J.
|
||||||
|
|
||||||
|
## Roadmap
|
||||||
|
* Replace Bitwise and Random namespaces with autogenerated code – In progress.
|
||||||
|
* Implement a convenient CLI tool.
|
||||||
|
* Define all Ops using the DSL.
|
||||||
|
* Automatically generate derivative op declarations from existing ops
|
||||||
|
* Replace all namespace definitions in ND4J / SameDiff with automatically generated ones
|
||||||
|
* Replace all Op classes with automatically generated ones.
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
Pre-requisites:
|
||||||
|
* JDK 8 or higher
|
||||||
|
* Maven 3.3 or higher
|
||||||
|
|
||||||
|
TODO: Show usage output of the project itself
|
||||||
|
|
||||||
|
TODO: Show how to use from mvn
|
||||||
|
|
||||||
|
|
||||||
|
## Generating Code - ND4J Namespaces
|
||||||
|
|
||||||
|
A script - `generate.sh` - is provided in the project root. This can be used (at present) to generate ND4J namespace classes.
|
||||||
|
It is assumed that the deeplearning4j mono repo and the dl4j-dev-tools repo both exist and have a common parent directory
|
||||||
|
i.e., `somedir/deeplearning4j` and `somedir/dl4j-dev-tools` both exist.
|
||||||
|
|
||||||
|
The script takes as argument the name (or names) of the ND4J namespaces to generate (not case sensitive) and projects (supported
|
||||||
|
projects are nd4j, sd and both by default).
|
||||||
|
|
||||||
|
As of 26/11, namespaces names (and hence valid args) include: `bitwise`, `neuralnetwork`, `random`, and `math`
|
||||||
|
Note also that `all` may be passed to the script to generate all namespaces.
|
||||||
|
|
||||||
|
For example, to generate both bitwise and random namespaces for both nd4j and SameDiff:
|
||||||
|
```
|
||||||
|
./generate.sh bitwise,random
|
||||||
|
```
|
||||||
|
Or to generate all namespaces for both nd4j and SameDiff, use:
|
||||||
|
```
|
||||||
|
./generate.sh all
|
||||||
|
```
|
||||||
|
To generate namespaces for one project only, use:
|
||||||
|
```
|
||||||
|
./generate.sh linalg -projects sd
|
||||||
|
```
|
||||||
|
or:
|
||||||
|
```
|
||||||
|
./generate.sh linalg -projects nd4j
|
||||||
|
```
|
||||||
|
The script will first compile the project, before running.
|
||||||
|
Internally, the `org.nd4j.codegen.cli.CLI` class is used.
|
||||||
|
Classes are written to `deeplearning4j/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/`
|
||||||
|
|
||||||
|
## Generating documentation.
|
||||||
|
It is possible to use generate.sh for generation of code only, docs in markdown format only, or both docs and code.
|
||||||
|
To generate docs only and store them to new folder "docs" for all namespaces:
|
||||||
|
```
|
||||||
|
./generate.sh all -docsdir ../../docs
|
||||||
|
```
|
||||||
|
Generation for selected namespaces works in the same way as for code:
|
||||||
|
```
|
||||||
|
./generate.sh -docsdir ../../docs bitwise,linalg
|
||||||
|
```
|
||||||
|
|
||||||
|
# Code structure
|
||||||
|
The project is implemented using a mix of Java and Kotlin. The DSL definition and the accompanying data structures are
|
||||||
|
implemented in Kotlin. At the moment the code generators are implemented in Java, in order to allow people who are not
|
||||||
|
fluent in Kotlin, but know Java to be able to contribute to the code generators.
|
||||||
|
|
||||||
|
The source code for this project is structured a bit different that what you would typically see in a Java or Kotlin
|
||||||
|
project. When you take a look inside the `src/main` directory, you will find 4 sub-directories.
|
||||||
|
|
||||||
|
The `java` and `kotlin` directories contain Java and Kotlin code respectively.
|
||||||
|
|
||||||
|
In order to not confuse op definitions with the machinery that allows them to be defined in that way, ops are kept in a
|
||||||
|
separate folder called `ops`.
|
||||||
|
|
||||||
|
Because we use JavaPoet for Java code generator implementation, we also have yet another folder called `stubs`. That
|
||||||
|
folder contains stub classes, that are used to reference other classes available in ND4J. These stub classes are
|
||||||
|
intentionally left empty, as JavaPoet only requires them for naming and automatically creating proper imports. We use
|
||||||
|
stub classes instead of depending on the actual nd4j API in order to break a cyclic dependency that would otherwise be
|
||||||
|
created (i.e. in order to be able to generate code for ND4J, we would need an already compiled nd4j to be available).
|
||||||
|
**Note:** If something is stubbed here and is moved in ND4J, then it also has to be moved to the appropriate place here,
|
||||||
|
otherwise the generated code will be wrong.
|
||||||
|
|
||||||
|
The `adr` folder contains "Architecture Decision Records". These files give you more insight into the "why" of some of
|
||||||
|
the bigger decisions within this project.
|
||||||
|
|
||||||
|
# DSL for Op Definition
|
||||||
|
Ops are defined using a DSL that is implemented in Kotlin. This means that other than the DSL, as defined in the
|
||||||
|
following, you can also use all of Kotlin when defining Ops. However, doing things the obvious and clearly
|
||||||
|
understandable way is better than coming up with a clever way, so prefer to use the DSL as described if unsure.
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
val mathNs = Namespace("math") {
|
||||||
|
Op("add") {
|
||||||
|
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic"
|
||||||
|
|
||||||
|
Input(NUMERIC, "x") { description = "First input to add" }
|
||||||
|
Input(NUMERIC,"y") { count = AtLeast(1); description = "Second input to add" }
|
||||||
|
Arg(INT,"shape") { count = AtLeast(1); description = "shape" }
|
||||||
|
|
||||||
|
|
||||||
|
Output(NUMERIC, "z") { description = "Output (x+y)" }
|
||||||
|
|
||||||
|
Doc(Language.ANY, DocScope.ALL) {
|
||||||
|
"""
|
||||||
|
(From AddOp) Add op doc text that will appear everywhere - classes, constructors, op creators
|
||||||
|
""".trimIndent()
|
||||||
|
}
|
||||||
|
Doc(Language.ANY, DocScope.CLASS_DOC_ONLY) {
|
||||||
|
"Add op doc text that will appear in all class docs (javadoc etc)"
|
||||||
|
}
|
||||||
|
Doc(Language.ANY, DocScope.CONSTRUCTORS_ONLY) {
|
||||||
|
"Add op doc text for constructors only"
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This example shows how a namespace is defined. Namespaces are at the top layer, and ops can only be defined within the
|
||||||
|
context of a namespace. This example namespace contains only a single op, called "add". If we wanted to add another op,
|
||||||
|
we would simply add it below the first.
|
||||||
|
|
||||||
|
As you can see, every op has to have a name, if you try to create one without a name, you will get a compile error.
|
||||||
|
Within the context of the op, we first set in which java package the op class can be found in, then define its inputs,
|
||||||
|
arguments and outputs and finally add some free form documentation about what that op is doing.
|
||||||
|
|
||||||
|
Like with the op itself, the inputs, arguments and outputs all have to have a name, but unlike the op, they also require
|
||||||
|
a type. Within their context, you can set a description and a count of how many parameters they can take respectively.
|
||||||
|
|
||||||
|
If an input, argument or output take anything else than exactly 1, they will be treated as arrays. Typically you would
|
||||||
|
use this to define ops like `concat` which can take multiple input tensors or ops that might take shape arguments.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
The following shows how a typical op definition looks like and how the generated Java code may look.
|
||||||
|
|
||||||
|
An op might be defined like this:
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
Op("binomial") {
|
||||||
|
javaPackage = "org.nd4j.linalg.api.ops.random.custom"
|
||||||
|
Arg(INT, "nTrials") { description = "Number of trials parameter for the binomial distribution" }
|
||||||
|
Arg(FLOATING_POINT, "p") { description = "Probability of success for each trial" }
|
||||||
|
Arg(INT, "shape") { count = AtLeast(1); description = "Shape of the new random SDVariable, as a 1D array" }
|
||||||
|
|
||||||
|
Output(NUMERIC, "output") { description = "new random SDVariable, where values are randomly sampled according to a Binomial distribution" }
|
||||||
|
|
||||||
|
Doc(Language.ANY, DocScope.ALL) {
|
||||||
|
"""
|
||||||
|
Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution,
|
||||||
|
with the specified number of trials and probability.
|
||||||
|
""".trimIndent()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The java code generator will create a method like the following for it:
|
||||||
|
```java
|
||||||
|
/**
|
||||||
|
* Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution,
|
||||||
|
* with the specified number of trials and probability.
|
||||||
|
*
|
||||||
|
* @param nTrials Number of trials parameter for the binomial distribution
|
||||||
|
* @param p Probability of success for each trial
|
||||||
|
* @param shape Shape of the new random SDVariable, as a 1D array (Size: AtLeast(min=1))
|
||||||
|
* @return output new random SDVariable, where values are randomly sampled according to a Binomial distribution (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public static INDArray binomial(long nTrials, double p, long... shape) {
|
||||||
|
Preconditions.checkArgument(shape.length >= 1, "shape has incorrect count. Expected: AtLeast(min=1)");
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.BinomialOp(nTrials, p, shape))[0];
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Or an op with some more constraints:
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
Op("and") {
|
||||||
|
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
|
||||||
|
val x = Input(INT, "x") { description = "First input array" }
|
||||||
|
val y = Input(INT, "y") { description = "Second input array" }
|
||||||
|
Constraint("Must be same types"){ sameType(x, y) }
|
||||||
|
Constraint("Must have broadcastable shapes"){ broadcastableShapes(x, y) }
|
||||||
|
|
||||||
|
Output(INT, "output"){ description = "Bitwise AND array" }
|
||||||
|
|
||||||
|
Doc(Language.ANY, DocScope.ALL){
|
||||||
|
"""
|
||||||
|
Bitwise AND operation. Supports broadcasting.
|
||||||
|
""".trimIndent()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
will be converted to java like this:
|
||||||
|
|
||||||
|
```java
|
||||||
|
/**
|
||||||
|
* Bitwise AND operation. Supports broadcasting.
|
||||||
|
*
|
||||||
|
* Inputs must satisfy the following constraints:
|
||||||
|
* Must be same types: isSameType(x, y)
|
||||||
|
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
|
||||||
|
*
|
||||||
|
* @param x First input array (INT type)
|
||||||
|
* @param y Second input array (INT type)
|
||||||
|
* @return output Bitwise AND array (INT type)
|
||||||
|
*/
|
||||||
|
public static INDArray and(INDArray x, INDArray y) {
|
||||||
|
NDValidation.validateInteger("and", x);
|
||||||
|
NDValidation.validateInteger("and", y);
|
||||||
|
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||||
|
Preconditions.checkArgument(isBroadcastableShapes(x, y), "Must have broadcastable shapes");
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.AndOp(x, y))[0];
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
# Full DSL Description
|
||||||
|
## Namespace
|
||||||
|
|
||||||
|
fun NamespaceName() = Namespace("name"){ /* Op definitions in namespace context */}
|
||||||
|
|
||||||
|
Defines a namespace.
|
||||||
|
|
||||||
|
## Op
|
||||||
|
Only available within a namespace context
|
||||||
|
|
||||||
|
Op("opName") { /* op properties in op context */ }
|
||||||
|
Op("anotherOp", mixin) { /* op properties in op context */ }
|
||||||
|
Op("anotherOp2", mixin, keepInputs=false) { /* op properties in op context */ }
|
||||||
|
|
||||||
|
Every op requires a namespace unique op name.
|
||||||
|
|
||||||
|
When defining an op, you can also pass a mixin that it should inherit initial properties from. This has the same effect
|
||||||
|
as using `useMixin(mixin)` as the very first thing in the op definition. If you don't want to inherit all of the
|
||||||
|
parameters of the mixin, you can pass the same additional configuration as you would pass to
|
||||||
|
`useMixin(mixin, ...options..)`. See [Mixin](#mixin) for more information.
|
||||||
|
|
||||||
|
### Op properties
|
||||||
|
* `javaPackage` (String): Package where the op is to be found in the java implementation.
|
||||||
|
* `javaOpClass` (String): Name of java op class if inconsistent with opName. Default: same as opName
|
||||||
|
* `libnd4jName` (String): The name the op has in libnd4j. Default: same as opName
|
||||||
|
|
||||||
|
|
||||||
|
## Mixin
|
||||||
|
Available in global context.
|
||||||
|
|
||||||
|
val mixin = Mixin("name"){ /* op properties in op context */ }
|
||||||
|
// within an op context:
|
||||||
|
useMixin(mixin)
|
||||||
|
useMixin(mixin, ...options...)
|
||||||
|
|
||||||
|
// When needing to access something from within the mixin
|
||||||
|
mixin.input("name")
|
||||||
|
mixin.arg("name")
|
||||||
|
mixin.config("name")
|
||||||
|
mixin.output("name")
|
||||||
|
|
||||||
|
Mixins provide the facility to share commonalities between Ops. You can think of it like inheritance, especially when
|
||||||
|
you declare the use of a mixin on Op definition. In contrast to normal (single) inheritance where only a single super
|
||||||
|
class is possible, the mixin mechanism allows to "inherit" from multiple sources.
|
||||||
|
|
||||||
|
You can define almost all the same things within a mixin that you can within an Op. The only things that *can not* be
|
||||||
|
configured within a mixin are Op `name`, `libnd4jName` and `javaOpClass`.
|
||||||
|
|
||||||
|
As mixins can be configured within the global context, you can share them across namespaces by defining them in their
|
||||||
|
own file. If a mixin is namespace specific, you can also define it within the namespace context.
|
||||||
|
|
||||||
|
Mixins are used either on definition as a parameter `Op("opname", mixin){...}`, or with `useMixin(mixin)` within the op
|
||||||
|
definition. While the former version only supports a single mixin, the latter version allows you to use as many mixins
|
||||||
|
as are required.
|
||||||
|
|
||||||
|
You can also build up mixins by using `useMixin(mixin)` inside a Mixin itself.
|
||||||
|
|
||||||
|
`useMixin(mixin, ...options...)` supports a few additional options: `keepInputs`, `keepArgs`, `keepConfigs`,
|
||||||
|
`keepOutputs`, `keepSignatures`, `keepDoc`, `keepConstraints`. They default to `true`. If you want to skip including
|
||||||
|
some of them, you simply set the parameter for it to `false`, e.g. `useMixin(mixin, keepDoc=false)`.
|
||||||
|
|
||||||
|
When using `useMixin(mixin)`, all definitions within the mixin are applied as if this invocation was replaced with the
|
||||||
|
content of the mixin itself. This means, that if you have already defined anything prior to using a mixin, the mixin's
|
||||||
|
definitions will be **after** the previously defined things. This can be very useful if the commonality between ops is
|
||||||
|
that they have a few trailing options.
|
||||||
|
|
||||||
|
If a named property or section is defined in both a mixin (or multiple mixins) and the op, then the **last** to define it will
|
||||||
|
win. Named properties are `legacy`, `javaPackage`, named sections are `Input`, `Arg`, `Output`, `Config`.
|
||||||
|
|
||||||
|
For example, assume you have `javaPackage` defined in both an op and a mixin. Then you can have the following two
|
||||||
|
cases:
|
||||||
|
|
||||||
|
First case:
|
||||||
|
```kotlin
|
||||||
|
Op("foo"){
|
||||||
|
useMixin(exampleMixin)
|
||||||
|
javaPackage = "some.example.package"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Second case:
|
||||||
|
```kotlin
|
||||||
|
Op("foo"){
|
||||||
|
javaPackage = "some.example.package"
|
||||||
|
useMixin(exampleMixin)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
In the first case, the op will have the `javaPackage` value that is defined within the op. In the second case it will
|
||||||
|
have the `javaPackage` value defined in the mixin.
|
||||||
|
|
||||||
|
For inputs, args, outputs, it works similarly. Assume you have `Input(dataType, "a")` defined in both the mixin and the
|
||||||
|
op. Again you can have two cases:
|
||||||
|
|
||||||
|
First case:
|
||||||
|
```kotlin
|
||||||
|
Op("foo"){
|
||||||
|
useMixin(exampleMixin)
|
||||||
|
Input(NUMERIC, "a")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Second case:
|
||||||
|
```kotlin
|
||||||
|
Op("foo"){
|
||||||
|
Input(NUMERIC, "a")
|
||||||
|
useMixin(exampleMixin)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
In the first case, it will overwrite the input from the mixin. In the second case, the mixin will overwrite that the
|
||||||
|
input from the op.
|
||||||
|
|
||||||
|
## Config
|
||||||
|
Only available within a namespace context
|
||||||
|
|
||||||
|
val nameConfig = Config("Name"){
|
||||||
|
/* input, arg, constraint, doc properties */
|
||||||
|
}
|
||||||
|
|
||||||
|
Every config requires a namespace unique name.
|
||||||
|
|
||||||
|
A config allows to define a configuration class, that can be used as a holder for complex properties of specific ops
|
||||||
|
which will be passed to an op as a parameter.
|
||||||
|
|
||||||
|
Similar to an op itself, it supports `Input`, `Arg`, `Constraint` and `Doc` definitions.
|
||||||
|
|
||||||
|
in order to use the config within an op you either use `useConfig(cfg)` or `val configRef = useConfig(cfg)`. The second
|
||||||
|
form allows you to reference the config.
|
||||||
|
|
||||||
|
Referencing the config allows to you reference its inputs and args by name: `configRef.input("name")` and
|
||||||
|
`configRef.arg("name")`. Also it allows you to use a config in a signature `Signature(a, b, c, configRef)`.
|
||||||
|
|
||||||
|
When default and shorthand signatures are used, configs will be always placed at the end.
|
||||||
|
|
||||||
|
If a config is defined but not used, an `IllegalStateException` will be thrown.
|
||||||
|
|
||||||
|
See also [ADR 0007 "Configuration Objects"](adr/0007-configuration_objects.md).
|
||||||
|
|
||||||
|
|
||||||
|
## Input
|
||||||
|
Available within an op, mixin and a config context
|
||||||
|
|
||||||
|
Input(FLOATING_POINT, "b"){ /* input properties in input context */ }
|
||||||
|
val a = Input(INT, "a"){ /* input properties in input context */ }
|
||||||
|
|
||||||
|
Inputs represent tensors. They are what the op will work on.
|
||||||
|
|
||||||
|
Every input requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name.
|
||||||
|
|
||||||
|
When defining an input, you can assign it to a variable in order to be able to reference it later on. You might want to
|
||||||
|
do this when defining constraints.
|
||||||
|
|
||||||
|
If you want an input to represent an array, you will have to set a count accordingly. If no count is set, it is assumed
|
||||||
|
that the count is meant to be `Exactly(1)`.
|
||||||
|
|
||||||
|
### Input properties
|
||||||
|
* `description` (String): A short description what this input represents. Setting this is recommended.
|
||||||
|
* `count` (Count): Can take one of `Exactly(n)`; `AtLeast(n)`; `AtMost(n)`; `Range(from, to)`
|
||||||
|
* `defaultValue` (Input): use another input as the default if this isn't set explicitly. The data type of the other
|
||||||
|
input has to match the data type of this input. The other input may also have a default value.
|
||||||
|
|
||||||
|
## Argument
|
||||||
|
Available within an op, mixin and config context
|
||||||
|
|
||||||
|
Arg(FLOATING_POINT, "b"){ /* Arg properties in arg context */ }
|
||||||
|
val a = Arg(INT, "a"){ /* Arg properties in arg context */ }
|
||||||
|
|
||||||
|
Args represent arguments. They modify how the op works on its inputs.
|
||||||
|
|
||||||
|
Every arg requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name.
|
||||||
|
|
||||||
|
When defining an arg, you can assign it to a variable in order to be able to reference it later on. You might want to do
|
||||||
|
this when defining constraints.
|
||||||
|
|
||||||
|
If you want an arg to represent an array, you will have to set a count accordingly. If no count is set, it is assumed
|
||||||
|
that the count is meant to be `Exactly(1)`.
|
||||||
|
|
||||||
|
Note (Java specific): If the last arg is defined to represent an array, it will be translated to a vararg parameter, e.g.
|
||||||
|
`Arg(INT, "a"){ count = AtLeast(1); description = "..." }` will be turned into `long... a`.
|
||||||
|
|
||||||
|
### Argument properties
|
||||||
|
* `description` (String): A short description what this argument represents. Setting this is recommended.
|
||||||
|
* `count` (Count): Can take one of `Exactly(n)`; `AtLeast(n)`; `AtMost(n)`; `Range(from, to)`
|
||||||
|
* `defaultValue` (null|Number|Boolean|int[]|double[]|boolean[]|Arg|TensorShapeValue|TensorDataTypeValue|String):
|
||||||
|
Use given value as default value, if this isn't explicitly set. Can refer to *inputs* and *outputs* using `x.shape()`
|
||||||
|
and `x.dataType()`. The given default values has to match the data type for this argument. May also refer to another
|
||||||
|
Arg, and that Arg may also have a default value. Default values based on outputs are treated like without a default
|
||||||
|
in SameDiff mode.
|
||||||
|
* `possibleValues` (String[]): only available when ENUM data type is used for the argument. Takes a list of possible
|
||||||
|
values for the Enum. If used in in abstract base op, the enum will only be created once. See also
|
||||||
|
[ADR 0006 "Op specific enums"](adr/0006-op_specific_enums.md).
|
||||||
|
|
||||||
|
|
||||||
|
## Output
|
||||||
|
Only available within an op and mixin context
|
||||||
|
|
||||||
|
Output(FLOATING_POINT, "b"){ /* Arg properties in arg context */ }
|
||||||
|
|
||||||
|
Every output requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name.
|
||||||
|
|
||||||
|
While outputs can be assigned to a variable, there is no intended use-case for it. In contrast to inputs and args,
|
||||||
|
outputs can not be used in constraints.
|
||||||
|
|
||||||
|
### Output properties
|
||||||
|
* `description` (String): A short description what this argument represents. Setting this is recommended.
|
||||||
|
|
||||||
|
|
||||||
|
## Signature
|
||||||
|
Only available within an op and mixin context
|
||||||
|
|
||||||
|
Signature(a,b,c)
|
||||||
|
Signature(a,b,c) { "Some Documentation" }
|
||||||
|
AllParamSignature()
|
||||||
|
AllDefaultParamSignature()
|
||||||
|
|
||||||
|
For some ops only specific signatures make sense, as for example some optional parameters may become required in the
|
||||||
|
presence of other optional parameters. This feature is mainly meant to help with the fact that not all programming
|
||||||
|
languages (e.g. Java) support default parameters. Each signature is meant to describe one overload in those languages.
|
||||||
|
|
||||||
|
See also [ADR 0005 "Optional parameters and signatures"](adr/0005-optional_parameters_and_signatures.md).
|
||||||
|
|
||||||
|
Signatures can also reference the output(s) of an op. Those signatures are only relevant in NDArray programming mode.
|
||||||
|
They are not to be generated in SameDiff mode.
|
||||||
|
|
||||||
|
`AllParamSignature()` and `AllDefaultParamSignature()` are short hands for `Signature(...all parameters...)` and
|
||||||
|
`Signature(...only parameters with no default values...)`. Their parameters include references to outputs unless
|
||||||
|
disabled using `withOutput=false` (e.g. `AllParamSignature(withOutput=false)`).
|
||||||
|
|
||||||
|
If no signature is specified for an op, it is treated as if `AllParamSignature()` and `AllDefaultParamSignature()` are
|
||||||
|
both specified.
|
||||||
|
|
||||||
|
Each signature must satisfy the condition, that all required parameters are listed there. If this condition is not
|
||||||
|
satisfied, an `IllegalStateException` will be thrown on construction.
|
||||||
|
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
Only available within an op and mixin context
|
||||||
|
|
||||||
|
Doc(Language.ANY, DocScope.ALL){
|
||||||
|
""" Some documentation
|
||||||
|
It can be multiline. And indented.
|
||||||
|
""".trimIndent()
|
||||||
|
}
|
||||||
|
|
||||||
|
Documentation can be language specific, and can be set to be only given at specific places. The documentation itself is
|
||||||
|
given as a string. Because Kotlin supports multiline strings along with proper indentation, we are using them directly
|
||||||
|
here.
|
||||||
|
|
||||||
|
Note: At the moment we are only creating java code, so the documentation can use JavaDoc syntax.
|
||||||
|
|
||||||
|
You can have multiple Doc definitions; they are treated as additive.
|
||||||
|
|
||||||
|
Any instances of the following values will be replaced when generating code:
|
||||||
|
|
||||||
|
* `%OPNAME%` -> operation name ("Add", "Sub", etc)
|
||||||
|
* `%LIBND4J_OPNAME%` -> libnd4j op name ("add", "sub", etc)
|
||||||
|
* `%INPUT_TYPE%` -> input / output type depending on the generated api, i.e. `SDVariable` for SameDiff and `INDArray`
|
||||||
|
for ND4J
|
||||||
|
|
||||||
|
See `DocTokens` class for more details.
|
||||||
|
|
||||||
|
## Constraints
|
||||||
|
Available within an op, mixin and a config context.
|
||||||
|
|
||||||
|
Constraint("Error Message if constraint isn't satisfied"){ /* constraint definition */ }
|
||||||
|
BackendConstraint("Error Message if constraint isn't satisfied"){ /* constraint definition */ }
|
||||||
|
|
||||||
|
Many ops expect their inputs and arguments to satisfy some specific rules. Those rules can be expressed with the
|
||||||
|
constraint system.
|
||||||
|
|
||||||
|
Constraints are to be enforced within the frontend language, while BackendConstraints are currently only to be used as
|
||||||
|
a part of the documentation. They will be enforced within the C++ backend, so there is no point in double checking them.
|
||||||
|
|
||||||
|
There is a system in place to define even complex constraints for inputs and arguments.
|
||||||
|
|
||||||
|
In a constraint definition, you can reference inputs and arguments directly, if they are previously assigned to
|
||||||
|
a variable using `val name = Input(...)`. Inside the Constraint block, you can use the following operations:
|
||||||
|
|
||||||
|
* `eq`: Compare equality (applicable to numbers and booleans), e.g. `x eq 7`, `x eq true`
|
||||||
|
* `neq`: Compare inequality (applicable to numbers and booleans), e.g. `x neq 3`, `x neq true`
|
||||||
|
* `lt`, `lte`: less than, less than equal (applicable to numbers), e.g. `x lt 3`, `x lte 4`
|
||||||
|
* `gt`, `gte`: greater than, grater than equal (applicable to numbers), e.g. `x gt 5`, `x gte 6`
|
||||||
|
* `and`: combine two comparisons where both have to be true, e.g. `(x eq 8) and (y lt 3)`
|
||||||
|
* `or`: combine two comparisons where one has to be true, e.g. `(x eq 8) or (y eq true)`
|
||||||
|
* `all`: combine N comparisons where all have to be true, e.g. `all(x eq 8, y lt 3, z eq true)`
|
||||||
|
* `some`: combine N comparisons where at least one has to be true, e.g. `some(x eq 8, y lt 3, z eq true)`
|
||||||
|
* `not`: negates a comparison, e.g. `not(x eq 3)`
|
||||||
|
|
||||||
|
In addition to those operations, you also get access to some more complex constraints:
|
||||||
|
* `sameType(...)`: true if all given inputs are the same type, e.g. `sameType(x,y,z)`
|
||||||
|
* `sameShape(...)`: true if all given inputs have the same shape, e.g. `sameShape(x,y,z)`
|
||||||
|
* `broadcastableShapes(...)`: true if all given inputs have broadcast compatible shapes, e.g. `broadcastableShapes(x,y,z)`
|
||||||
|
|
||||||
|
Inputs also get some additional methods on them to define useful constraints:
|
||||||
|
* `input.rank()`: Rank of the given input
|
||||||
|
* `input.sizeAt(i)`: size of the given input at the i-th dimension
|
||||||
|
* `input.isScalar()`: Short hand for `x.rank() == 1`
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
Some examples of constraints, and what they evaluate to. The example code contains a little bit of context.
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
val x = Input(INT, "x") { description = "First input array" }
|
||||||
|
val y = Input(INT, "y") { description = "Second input array" }
|
||||||
|
Constraint("foo bar"){
|
||||||
|
x.sizeAt(7) eq 7 and y.isScalar()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
will evaluate to:
|
||||||
|
```java
|
||||||
|
Preconditions.checkArgument((x.sizeAt(7) == 7) && (y.rank() == 1), "foo bar");
|
||||||
|
```
|
||||||
|
|
||||||
|
More examples (only the constraint itself, without context code):
|
||||||
|
|
||||||
|
#### Some
|
||||||
|
```kotlin
|
||||||
|
some(input.rank() eq 3, input.sizeAt(2) gte 7, input.sizeAt(4) lt 5)
|
||||||
|
```
|
||||||
|
turns to:
|
||||||
|
```java
|
||||||
|
((x.rank() == 3) || (x.sizeAt(2) >= 7)) || (x.sizeAt(4) < 5)
|
||||||
|
```
|
||||||
|
|
||||||
|
# Contributing to this project
|
||||||
|
If you want to contribute to this project other than by adding or improving op definitions, the following sections might
|
||||||
|
be of special interest to you.
|
||||||
|
|
||||||
|
## Extending the DSL
|
||||||
|
The DSL is implemented using Kotlin’s type-safe builders feature
|
||||||
|
(see https://kotlinlang.org/docs/reference/type-safe-builders.html). The basic principle is that functions calls can
|
||||||
|
receive blocks that can be executed in a specified context. When combined with the fact that we are just looking to
|
||||||
|
create an object graph that is then going to be used as input to the code generators, this allows us to create a very
|
||||||
|
feature rich DSL without actually having to write a lot of code to support it.
|
||||||
|
|
||||||
|
Most of the DSL specific code can be found in `src/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt`. The actual class
|
||||||
|
definitions for the object graph we are building, can be found in `src/kotlin/org/nd4j/codegen/api`.
|
||||||
|
|
||||||
|
If you want to add just a simple field to one of the objects, it is usually enough to directly add it to the particular
|
||||||
|
class.
|
||||||
|
|
||||||
|
If you want to add a specific section to the op definition, i.e. a section like Input or Doc, you will have to add both
|
||||||
|
the class for the object that it is going to be creating, as well as a function within OpBuilder.kt to create and
|
||||||
|
register that section within the op.
|
||||||
|
|
||||||
|
**Note:** When you extend the DSL you will most likely also have to update all code generators to support the feature
|
||||||
|
you have added.
|
||||||
|
|
||||||
|
## Adding / extending code generators
|
||||||
|
Code generators can be written in either Java or Kotlin. Java has the advantage that more people will have experience in
|
||||||
|
using it. Kotlin has the advantage of more convenient syntax, especially for plain string manipulation and when dealing
|
||||||
|
with Enums and fixed sets of subclasses (called sealed classes in Kotlin).
|
||||||
|
|
||||||
|
All generators have to implement the `org.nd4j.codegen.api.generator.Generator` interface. For automatic detection by
|
||||||
|
the CLI tool, they should also be within the `org.nd4j.codegen.impl.LANGUAGE` package, where `LANGUAGE` is the actual
|
||||||
|
language that they generate.
|
||||||
|
|
||||||
|
Code generators can also use an auxiliary generator for constraint generation. Those auxiliary generators, have to
|
||||||
|
implement ` org.nd4j.codegen.api.generator.ConstraintCodeGenerator` interface.
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
if test "$#" -eq 0; then
|
||||||
|
echo "No namespaces were specified. One or more namespaces must be provided as an argument"
|
||||||
|
echo "Usage example 1 (single namespace): ./generate.sh math"
|
||||||
|
echo "Usage example 2 (multiple namespaces): ./generate.sh math,random"
|
||||||
|
echo "Usage example 2 (all namespaces): ./generate.sh all"
|
||||||
|
else
|
||||||
|
mvn clean package -DskipTests
|
||||||
|
java -cp target/codegen-1.0.0-SNAPSHOT-shaded.jar org.nd4j.codegen.cli.CLI -dir ../../ -namespaces "$@"
|
||||||
|
fi
|
|
@ -0,0 +1,258 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>codegen</artifactId>
|
||||||
|
<version>1.0.0-SNAPSHOT</version>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
|
<commonsio.version>2.5</commonsio.version>
|
||||||
|
<commons.lang.version>3.12.0</commons.lang.version>
|
||||||
|
<commons.dbutils.version>1.7</commons.dbutils.version>
|
||||||
|
<lombok.version>1.18.8</lombok.version>
|
||||||
|
<logback.version>1.1.7</logback.version>
|
||||||
|
<junit.version>4.12</junit.version>
|
||||||
|
<junit-jupiter.version>5.4.2</junit-jupiter.version>
|
||||||
|
<java.version>1.8</java.version>
|
||||||
|
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
|
||||||
|
<kotlin.version>1.3.50</kotlin.version>
|
||||||
|
<kotlin.compiler.jvmTarget>1.8</kotlin.compiler.jvmTarget>
|
||||||
|
<kotlin.compiler.incremental>true</kotlin.compiler.incremental>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.slf4j</groupId>
|
||||||
|
<artifactId>slf4j-api</artifactId>
|
||||||
|
<version>1.7.28</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>ch.qos.logback</groupId>
|
||||||
|
<artifactId>logback-classic</artifactId>
|
||||||
|
<version>${logback.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>commons-io</groupId>
|
||||||
|
<artifactId>commons-io</artifactId>
|
||||||
|
<version>${commonsio.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.projectlombok</groupId>
|
||||||
|
<artifactId>lombok</artifactId>
|
||||||
|
<version>${lombok.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-lang3</artifactId>
|
||||||
|
<version>${commons.lang.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.squareup</groupId>
|
||||||
|
<artifactId>javapoet</artifactId>
|
||||||
|
<version>1.11.1</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
|
||||||
|
<!-- Test Dependencies -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-api</artifactId>
|
||||||
|
<version>${junit-jupiter.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-engine</artifactId>
|
||||||
|
<version>${junit-jupiter.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.jetbrains.kotlin</groupId>
|
||||||
|
<artifactId>kotlin-stdlib-jdk8</artifactId>
|
||||||
|
<version>${kotlin.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.jetbrains.kotlin</groupId>
|
||||||
|
<artifactId>kotlin-test</artifactId>
|
||||||
|
<version>${kotlin.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.fasterxml.jackson.module</groupId>
|
||||||
|
<artifactId>jackson-module-kotlin</artifactId>
|
||||||
|
<version>2.9.9</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.beust</groupId>
|
||||||
|
<artifactId>jcommander</artifactId>
|
||||||
|
<version>1.78</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-api</artifactId>
|
||||||
|
<version>1.0.0-SNAPSHOT</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.codehaus.mojo</groupId>
|
||||||
|
<artifactId>build-helper-maven-plugin</artifactId>
|
||||||
|
<version>3.0.0</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>add-source</id>
|
||||||
|
<phase>generate-sources</phase>
|
||||||
|
<goals><goal>add-source</goal></goals>
|
||||||
|
<configuration>
|
||||||
|
<sources>
|
||||||
|
<source>src/main/stubs</source>
|
||||||
|
</sources>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-shade-plugin</artifactId>
|
||||||
|
<version>${maven-shade-plugin.version}</version>
|
||||||
|
<configuration>
|
||||||
|
<shadedArtifactAttached>true</shadedArtifactAttached>
|
||||||
|
<createDependencyReducedPom>false</createDependencyReducedPom>
|
||||||
|
<filters>
|
||||||
|
<filter>
|
||||||
|
<artifact>*:*</artifact>
|
||||||
|
<excludes>
|
||||||
|
<exclude>org/datanucleus/**</exclude>
|
||||||
|
<exclude>META-INF/*.SF</exclude>
|
||||||
|
<exclude>META-INF/*.DSA</exclude>
|
||||||
|
<exclude>META-INF/*.RSA</exclude>
|
||||||
|
</excludes>
|
||||||
|
</filter>
|
||||||
|
</filters>
|
||||||
|
|
||||||
|
</configuration>
|
||||||
|
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<phase>package</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>shade</goal>
|
||||||
|
</goals>
|
||||||
|
<configuration>
|
||||||
|
<transformers>
|
||||||
|
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
|
||||||
|
<resource>reference.conf</resource>
|
||||||
|
</transformer>
|
||||||
|
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
|
||||||
|
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer" />
|
||||||
|
</transformers>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.jetbrains.kotlin</groupId>
|
||||||
|
<artifactId>kotlin-maven-plugin</artifactId>
|
||||||
|
<version>${kotlin.version}</version>
|
||||||
|
<configuration>
|
||||||
|
<args>
|
||||||
|
<arg>-Xjsr305=strict</arg>
|
||||||
|
</args>
|
||||||
|
<compilerPlugins>
|
||||||
|
<plugin>spring</plugin>
|
||||||
|
<plugin>jpa</plugin>
|
||||||
|
</compilerPlugins>
|
||||||
|
</configuration>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.jetbrains.kotlin</groupId>
|
||||||
|
<artifactId>kotlin-maven-allopen</artifactId>
|
||||||
|
<version>${kotlin.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.jetbrains.kotlin</groupId>
|
||||||
|
<artifactId>kotlin-maven-noarg</artifactId>
|
||||||
|
<version>${kotlin.version}</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>compile</id>
|
||||||
|
<goals> <goal>compile</goal> </goals>
|
||||||
|
<configuration>
|
||||||
|
<sourceDirs>
|
||||||
|
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/main/java</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
|
||||||
|
</sourceDirs>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
<execution>
|
||||||
|
<id>test-compile</id>
|
||||||
|
<goals> <goal>test-compile</goal> </goals>
|
||||||
|
<configuration>
|
||||||
|
<sourceDirs>
|
||||||
|
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/test/java</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
|
||||||
|
</sourceDirs>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-compiler-plugin</artifactId>
|
||||||
|
<version>3.5.1</version>
|
||||||
|
<executions>
|
||||||
|
<!-- Replacing default-compile as it is treated specially by maven -->
|
||||||
|
<execution>
|
||||||
|
<id>default-compile</id>
|
||||||
|
<phase>none</phase>
|
||||||
|
</execution>
|
||||||
|
<!-- Replacing default-testCompile as it is treated specially by maven -->
|
||||||
|
<execution>
|
||||||
|
<id>default-testCompile</id>
|
||||||
|
<phase>none</phase>
|
||||||
|
</execution>
|
||||||
|
<execution>
|
||||||
|
<id>java-compile</id>
|
||||||
|
<phase>compile</phase>
|
||||||
|
<goals> <goal>compile</goal> </goals>
|
||||||
|
</execution>
|
||||||
|
<execution>
|
||||||
|
<id>java-test-compile</id>
|
||||||
|
<phase>test-compile</phase>
|
||||||
|
<goals> <goal>testCompile</goal> </goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
<configuration>
|
||||||
|
<source>${java.version}</source>
|
||||||
|
<target>${java.version}</target>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
|
||||||
|
</project>
|
|
@ -19,17 +19,13 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.nd4j.codegen.impl.java;
|
package org.nd4j.codegen.impl.java;
|
||||||
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.nd4j.codegen.api.Language;
|
import org.nd4j.codegen.api.Language;
|
||||||
import org.nd4j.codegen.api.Namespace;
|
|
||||||
import org.nd4j.codegen.api.NamespaceOps;
|
import org.nd4j.codegen.api.NamespaceOps;
|
||||||
import org.nd4j.codegen.api.Op;
|
|
||||||
import org.nd4j.codegen.api.generator.Generator;
|
import org.nd4j.codegen.api.generator.Generator;
|
||||||
import org.nd4j.codegen.api.generator.GeneratorConfig;
|
import org.nd4j.codegen.api.generator.GeneratorConfig;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
public class JavaPoetGenerator implements Generator {
|
public class JavaPoetGenerator implements Generator {
|
||||||
|
|
||||||
|
@ -40,12 +36,12 @@ public class JavaPoetGenerator implements Generator {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void generateNamespaceNd4j(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws IOException {
|
public void generateNamespaceNd4j(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws java.io.IOException {
|
||||||
Nd4jNamespaceGenerator.generate(namespace, config, directory, className, "org.nd4j.linalg.factory", StringUtils.EMPTY);
|
Nd4jNamespaceGenerator.generate(namespace, config, directory, className, "org.nd4j.linalg.factory", StringUtils.EMPTY);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void generateNamespaceSameDiff(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws IOException {
|
public void generateNamespaceSameDiff(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws java.io.IOException {
|
||||||
//throw new UnsupportedOperationException("Not yet implemented");
|
//throw new UnsupportedOperationException("Not yet implemented");
|
||||||
Nd4jNamespaceGenerator.generate(namespace, config, directory, className, "org.nd4j.autodiff.samediff", StringUtils.EMPTY);
|
Nd4jNamespaceGenerator.generate(namespace, config, directory, className, "org.nd4j.autodiff.samediff", StringUtils.EMPTY);
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,6 +47,23 @@ fun SDLoss() = Namespace("Loss"){
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Op("ctcLoss") {
|
||||||
|
javaPackage = "org.nd4j.linalg.api.ops.impl.loss"
|
||||||
|
javaOpClass = "CtcLoss"
|
||||||
|
Input(NUMERIC, "targetLabels") { description = "Label array" }
|
||||||
|
Input(NUMERIC, "logitInput") { description = "Inputs" }
|
||||||
|
Input(NUMERIC, "targetLabelLengths") { description = "Length of the target label" }
|
||||||
|
Input(NUMERIC, "logitInputLengths") { description = "Length of the input"}
|
||||||
|
Output(NUMERIC, "output"){ description = "Ctc loss " }
|
||||||
|
Doc(Language.ANY, DocScope.ALL){
|
||||||
|
"""
|
||||||
|
CTC Loss: Connectionist Temporal Classification Loss. See:
|
||||||
|
https://dl.acm.org/citation.cfm?id=1143891
|
||||||
|
""".trimIndent()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Op("cosineDistance") {
|
Op("cosineDistance") {
|
||||||
javaPackage = "org.nd4j.linalg.api.ops.impl.loss"
|
javaPackage = "org.nd4j.linalg.api.ops.impl.loss"
|
||||||
javaOpClass = "CosineDistanceLoss"
|
javaOpClass = "CosineDistanceLoss"
|
||||||
|
|
|
@ -1,32 +1,26 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||||
|
|
||||||
import java.lang.String;
|
import java.lang.String;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -2799,47 +2793,6 @@ public class SDBaseOps {
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Split the input in to a list of sub tensors
|
|
||||||
* @param input the input to split
|
|
||||||
* @param numSizeSplits the number of splits
|
|
||||||
* @param splitDim the dimension to split along
|
|
||||||
* @return the set of output variables
|
|
||||||
*/
|
|
||||||
public SDVariable[] split(SDVariable input,int numSizeSplits,int splitDim) {
|
|
||||||
SDValidation.validateNumerical("split",input);
|
|
||||||
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.Split(sd,input,numSizeSplits,splitDim).outputVariables();
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Split the input in to a list of sub tensors
|
|
||||||
* @param name the potential name of the input
|
|
||||||
* @param input the input to split
|
|
||||||
* @param numSizeSplits the number of splits
|
|
||||||
* @param splitDim the dimension to split along
|
|
||||||
* @return the set of output variables
|
|
||||||
*/
|
|
||||||
public SDVariable[] split(String name,SDVariable input,int numSizeSplits,int splitDim) {
|
|
||||||
SDValidation.validateNumerical("split",input);
|
|
||||||
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.Split(sd,input,numSizeSplits,splitDim).outputVariables();
|
|
||||||
SDVariable[] ret = new SDVariable[out.length];
|
|
||||||
AtomicInteger index = new AtomicInteger(0);
|
|
||||||
Arrays.stream(out).forEach(output -> {
|
|
||||||
if(index.get() < 1) {
|
|
||||||
ret[index.get()] = sd.updateVariableNameAndReference(output,name);
|
|
||||||
index.incrementAndGet();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
ret[index.get()] = sd.updateVariableNameAndReference(output,name + ":" + index.get());
|
|
||||||
index.incrementAndGet();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br>
|
* Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br>
|
||||||
* input, but with the specified shape.<br>
|
* input, but with the specified shape.<br>
|
||||||
|
|
|
@ -1,22 +1,20 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
@ -250,7 +248,7 @@ public class SDBitwise extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
||||||
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
* Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||||
*
|
*
|
||||||
* @param x Input to be bit shifted (INT type)
|
* @param x Input to be bit shifted (INT type)
|
||||||
|
@ -265,7 +263,7 @@ public class SDBitwise extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
||||||
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
* Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
|
@ -348,7 +346,7 @@ public class SDBitwise extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
||||||
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
* Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||||
*
|
*
|
||||||
* @param x Input to be bit shifted (INT type)
|
* @param x Input to be bit shifted (INT type)
|
||||||
|
@ -363,7 +361,7 @@ public class SDBitwise extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
||||||
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
* Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||||
|
|
||||||
import java.lang.String;
|
import java.lang.String;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -42,8 +42,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - average pooling 2d<br>
|
* 2D Convolution layer operation - average pooling 2d<br>
|
||||||
*
|
*
|
||||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling2DConfig Configuration Object
|
* @param Pooling2DConfig Configuration Object
|
||||||
* @return output Result after applying average pooling on the input (NUMERIC type)
|
* @return output Result after applying average pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -56,8 +55,7 @@ public class SDCNN extends SDOps {
|
||||||
* 2D Convolution layer operation - average pooling 2d<br>
|
* 2D Convolution layer operation - average pooling 2d<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling2DConfig Configuration Object
|
* @param Pooling2DConfig Configuration Object
|
||||||
* @return output Result after applying average pooling on the input (NUMERIC type)
|
* @return output Result after applying average pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -70,9 +68,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* 3D convolution layer operation - average pooling 3d <br>
|
* 3D convolution layer operation - average pooling 3d <br>
|
||||||
*
|
*
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling3DConfig Configuration Object
|
* @param Pooling3DConfig Configuration Object
|
||||||
* @return output after applying average pooling on the input (NUMERIC type)
|
* @return output after applying average pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -85,9 +81,7 @@ public class SDCNN extends SDOps {
|
||||||
* 3D convolution layer operation - average pooling 3d <br>
|
* 3D convolution layer operation - average pooling 3d <br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling3DConfig Configuration Object
|
* @param Pooling3DConfig Configuration Object
|
||||||
* @return output after applying average pooling on the input (NUMERIC type)
|
* @return output after applying average pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -302,9 +296,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* Convolution 3D operation with optional bias <br>
|
* Convolution 3D operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||||
* @param Conv3DConfig Configuration Object
|
* @param Conv3DConfig Configuration Object
|
||||||
|
@ -322,9 +314,7 @@ public class SDCNN extends SDOps {
|
||||||
* Convolution 3D operation with optional bias <br>
|
* Convolution 3D operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||||
* @param Conv3DConfig Configuration Object
|
* @param Conv3DConfig Configuration Object
|
||||||
|
@ -342,9 +332,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* Convolution 3D operation with optional bias <br>
|
* Convolution 3D operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||||
* @param Conv3DConfig Configuration Object
|
* @param Conv3DConfig Configuration Object
|
||||||
* @return output Conv3d output variable (NUMERIC type)
|
* @return output Conv3d output variable (NUMERIC type)
|
||||||
|
@ -359,9 +347,7 @@ public class SDCNN extends SDOps {
|
||||||
* Convolution 3D operation with optional bias <br>
|
* Convolution 3D operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||||
* @param Conv3DConfig Configuration Object
|
* @param Conv3DConfig Configuration Object
|
||||||
* @return output Conv3d output variable (NUMERIC type)
|
* @return output Conv3d output variable (NUMERIC type)
|
||||||
|
@ -377,8 +363,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* 2D deconvolution operation with optional bias<br>
|
* 2D deconvolution operation with optional bias<br>
|
||||||
*
|
*
|
||||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||||
* @param DeConv2DConfig Configuration Object
|
* @param DeConv2DConfig Configuration Object
|
||||||
|
@ -396,8 +381,7 @@ public class SDCNN extends SDOps {
|
||||||
* 2D deconvolution operation with optional bias<br>
|
* 2D deconvolution operation with optional bias<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||||
* @param DeConv2DConfig Configuration Object
|
* @param DeConv2DConfig Configuration Object
|
||||||
|
@ -415,8 +399,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* 2D deconvolution operation with optional bias<br>
|
* 2D deconvolution operation with optional bias<br>
|
||||||
*
|
*
|
||||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||||
* @param DeConv2DConfig Configuration Object
|
* @param DeConv2DConfig Configuration Object
|
||||||
* @return output result of deconv2d op (NUMERIC type)
|
* @return output result of deconv2d op (NUMERIC type)
|
||||||
|
@ -432,8 +415,7 @@ public class SDCNN extends SDOps {
|
||||||
* 2D deconvolution operation with optional bias<br>
|
* 2D deconvolution operation with optional bias<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||||
* @param DeConv2DConfig Configuration Object
|
* @param DeConv2DConfig Configuration Object
|
||||||
* @return output result of deconv2d op (NUMERIC type)
|
* @return output result of deconv2d op (NUMERIC type)
|
||||||
|
@ -519,8 +501,7 @@ public class SDCNN extends SDOps {
|
||||||
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
||||||
* = [mb, 2, 4, 4]<br>
|
* = [mb, 2, 4, 4]<br>
|
||||||
*
|
*
|
||||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param blockSize Block size, in the height/width dimension
|
* @param blockSize Block size, in the height/width dimension
|
||||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
|
@ -537,8 +518,7 @@ public class SDCNN extends SDOps {
|
||||||
* = [mb, 2, 4, 4]<br>
|
* = [mb, 2, 4, 4]<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param blockSize Block size, in the height/width dimension
|
* @param blockSize Block size, in the height/width dimension
|
||||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
|
@ -756,8 +736,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
|
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
|
||||||
*
|
*
|
||||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling2DConfig Configuration Object
|
* @param Pooling2DConfig Configuration Object
|
||||||
*/
|
*/
|
||||||
public SDVariable[] maxPoolWithArgmax(SDVariable input, Pooling2DConfig Pooling2DConfig) {
|
public SDVariable[] maxPoolWithArgmax(SDVariable input, Pooling2DConfig Pooling2DConfig) {
|
||||||
|
@ -769,8 +748,7 @@ public class SDCNN extends SDOps {
|
||||||
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
|
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
|
||||||
*
|
*
|
||||||
* @param names names May be null. Arrays of names for the output variables.
|
* @param names names May be null. Arrays of names for the output variables.
|
||||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling2DConfig Configuration Object
|
* @param Pooling2DConfig Configuration Object
|
||||||
*/
|
*/
|
||||||
public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable input,
|
public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable input,
|
||||||
|
@ -783,8 +761,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - max pooling 2d <br>
|
* 2D Convolution layer operation - max pooling 2d <br>
|
||||||
*
|
*
|
||||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling2DConfig Configuration Object
|
* @param Pooling2DConfig Configuration Object
|
||||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -797,8 +774,7 @@ public class SDCNN extends SDOps {
|
||||||
* 2D Convolution layer operation - max pooling 2d <br>
|
* 2D Convolution layer operation - max pooling 2d <br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling2DConfig Configuration Object
|
* @param Pooling2DConfig Configuration Object
|
||||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -811,9 +787,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* 3D convolution layer operation - max pooling 3d operation.<br>
|
* 3D convolution layer operation - max pooling 3d operation.<br>
|
||||||
*
|
*
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling3DConfig Configuration Object
|
* @param Pooling3DConfig Configuration Object
|
||||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -826,9 +800,7 @@ public class SDCNN extends SDOps {
|
||||||
* 3D convolution layer operation - max pooling 3d operation.<br>
|
* 3D convolution layer operation - max pooling 3d operation.<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling3DConfig Configuration Object
|
* @param Pooling3DConfig Configuration Object
|
||||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -841,8 +813,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* Separable 2D convolution operation with optional bias <br>
|
* Separable 2D convolution operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||||
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
|
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
|
||||||
|
@ -862,8 +833,7 @@ public class SDCNN extends SDOps {
|
||||||
* Separable 2D convolution operation with optional bias <br>
|
* Separable 2D convolution operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||||
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
|
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
|
||||||
|
@ -883,8 +853,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* Separable 2D convolution operation with optional bias <br>
|
* Separable 2D convolution operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||||
* @param Conv2DConfig Configuration Object
|
* @param Conv2DConfig Configuration Object
|
||||||
|
@ -902,8 +871,7 @@ public class SDCNN extends SDOps {
|
||||||
* Separable 2D convolution operation with optional bias <br>
|
* Separable 2D convolution operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||||
* @param Conv2DConfig Configuration Object
|
* @param Conv2DConfig Configuration Object
|
||||||
|
@ -964,8 +932,7 @@ public class SDCNN extends SDOps {
|
||||||
* Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
* Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
||||||
* = [mb, 2, 4, 4] <br>
|
* = [mb, 2, 4, 4] <br>
|
||||||
*
|
*
|
||||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param blockSize Block size, in the height/width dimension
|
* @param blockSize Block size, in the height/width dimension
|
||||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
|
@ -982,8 +949,7 @@ public class SDCNN extends SDOps {
|
||||||
* = [mb, 2, 4, 4] <br>
|
* = [mb, 2, 4, 4] <br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param blockSize Block size, in the height/width dimension
|
* @param blockSize Block size, in the height/width dimension
|
||||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||||
|
|
||||||
import java.lang.String;
|
import java.lang.String;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||||
|
|
||||||
import java.lang.String;
|
import java.lang.String;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||||
|
|
||||||
import java.lang.String;
|
import java.lang.String;
|
||||||
import org.nd4j.autodiff.loss.LossReduce;
|
import org.nd4j.autodiff.loss.LossReduce;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -36,7 +36,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output loss variable (NUMERIC type)
|
* @return output loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights,
|
public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights,
|
||||||
|
@ -56,7 +56,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output loss variable (NUMERIC type)
|
* @return output loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions,
|
public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions,
|
||||||
|
@ -116,7 +116,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param dimension Dimension to perform the cosine distance over
|
* @param dimension Dimension to perform the cosine distance over
|
||||||
* @return output Cosine distance loss (NUMERIC type)
|
* @return output Cosine distance loss (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -141,7 +141,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param dimension Dimension to perform the cosine distance over
|
* @param dimension Dimension to perform the cosine distance over
|
||||||
* @return output Cosine distance loss (NUMERIC type)
|
* @return output Cosine distance loss (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -202,6 +202,49 @@ public class SDLoss extends SDOps {
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CTC Loss: Connectionist Temporal Classification Loss. See:<br>
|
||||||
|
* https://dl.acm.org/citation.cfm?id=1143891<br>
|
||||||
|
*
|
||||||
|
* @param targetLabels Label array (NUMERIC type)
|
||||||
|
* @param logitInput Inputs (NUMERIC type)
|
||||||
|
* @param targetLabelLengths Length of the target label (NUMERIC type)
|
||||||
|
* @param logitInputLengths Length of the input (NUMERIC type)
|
||||||
|
* @return output Ctc loss (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public SDVariable ctcLoss(SDVariable targetLabels, SDVariable logitInput,
|
||||||
|
SDVariable targetLabelLengths, SDVariable logitInputLengths) {
|
||||||
|
SDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels);
|
||||||
|
SDValidation.validateNumerical("ctcLoss", "logitInput", logitInput);
|
||||||
|
SDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths);
|
||||||
|
SDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths);
|
||||||
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(sd,targetLabels, logitInput, targetLabelLengths, logitInputLengths).outputVariable();
|
||||||
|
out.markAsLoss();
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CTC Loss: Connectionist Temporal Classification Loss. See:<br>
|
||||||
|
* https://dl.acm.org/citation.cfm?id=1143891<br>
|
||||||
|
*
|
||||||
|
* @param name name May be null. Name for the output variable
|
||||||
|
* @param targetLabels Label array (NUMERIC type)
|
||||||
|
* @param logitInput Inputs (NUMERIC type)
|
||||||
|
* @param targetLabelLengths Length of the target label (NUMERIC type)
|
||||||
|
* @param logitInputLengths Length of the input (NUMERIC type)
|
||||||
|
* @return output Ctc loss (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public SDVariable ctcLoss(String name, SDVariable targetLabels, SDVariable logitInput,
|
||||||
|
SDVariable targetLabelLengths, SDVariable logitInputLengths) {
|
||||||
|
SDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels);
|
||||||
|
SDValidation.validateNumerical("ctcLoss", "logitInput", logitInput);
|
||||||
|
SDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths);
|
||||||
|
SDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths);
|
||||||
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(sd,targetLabels, logitInput, targetLabelLengths, logitInputLengths).outputVariable();
|
||||||
|
out.markAsLoss();
|
||||||
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Hinge loss: a loss function used for training classifiers.<br>
|
* Hinge loss: a loss function used for training classifiers.<br>
|
||||||
* Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
|
* Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
|
||||||
|
@ -210,7 +253,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
|
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights,
|
public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights,
|
||||||
|
@ -232,7 +275,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
|
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions,
|
public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions,
|
||||||
|
@ -297,7 +340,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param delta Loss function delta value
|
* @param delta Loss function delta value
|
||||||
* @return output Huber loss (NUMERIC type)
|
* @return output Huber loss (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -324,7 +367,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param delta Loss function delta value
|
* @param delta Loss function delta value
|
||||||
* @return output Huber loss (NUMERIC type)
|
* @return output Huber loss (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -423,7 +466,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param epsilon epsilon
|
* @param epsilon epsilon
|
||||||
* @return output Log loss (NUMERIC type)
|
* @return output Log loss (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -445,7 +488,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param epsilon epsilon
|
* @param epsilon epsilon
|
||||||
* @return output Log loss (NUMERIC type)
|
* @return output Log loss (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -499,7 +542,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
|
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
|
||||||
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
|
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
|
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -521,7 +564,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
|
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
|
||||||
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
|
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
|
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -585,7 +628,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output Loss variable, scalar output (NUMERIC type)
|
* @return output Loss variable, scalar output (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions,
|
public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions,
|
||||||
|
@ -608,7 +651,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output Loss variable, scalar output (NUMERIC type)
|
* @return output Loss variable, scalar output (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions,
|
public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions,
|
||||||
|
@ -666,13 +709,13 @@ public class SDLoss extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||||
* this is the mean squared error loss function.<br>
|
* this is the mean squared error loss function.<br>
|
||||||
*
|
*
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights,
|
public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights,
|
||||||
|
@ -687,14 +730,14 @@ public class SDLoss extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||||
* this is the mean squared error loss function.<br>
|
* this is the mean squared error loss function.<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions,
|
public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions,
|
||||||
|
@ -709,7 +752,7 @@ public class SDLoss extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||||
* this is the mean squared error loss function.<br>
|
* this is the mean squared error loss function.<br>
|
||||||
*
|
*
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
|
@ -728,7 +771,7 @@ public class SDLoss extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||||
* this is the mean squared error loss function.<br>
|
* this is the mean squared error loss function.<br>
|
||||||
*
|
*
|
||||||
* @param name name May be null. Name for the output variable
|
* @param name name May be null. Name for the output variable
|
||||||
|
@ -764,7 +807,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictionLogits Predictions array (NUMERIC type)
|
* @param predictionLogits Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -796,7 +839,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictionLogits Predictions array (NUMERIC type)
|
* @param predictionLogits Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -872,7 +915,7 @@ public class SDLoss extends SDOps {
|
||||||
/**
|
/**
|
||||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||||
* otherwise, the output is a scalar.<br>
|
* otherwise, the output is a scalar.<br>
|
||||||
* <p><br>
|
* <p><br>
|
||||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||||
|
@ -884,7 +927,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
|
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
|
||||||
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
|
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -901,7 +944,7 @@ public class SDLoss extends SDOps {
|
||||||
/**
|
/**
|
||||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||||
* otherwise, the output is a scalar.<br>
|
* otherwise, the output is a scalar.<br>
|
||||||
* <p><br>
|
* <p><br>
|
||||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||||
|
@ -914,7 +957,7 @@ public class SDLoss extends SDOps {
|
||||||
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
|
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
|
||||||
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
|
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -932,7 +975,7 @@ public class SDLoss extends SDOps {
|
||||||
/**
|
/**
|
||||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||||
* otherwise, the output is a scalar.<br>
|
* otherwise, the output is a scalar.<br>
|
||||||
* <p><br>
|
* <p><br>
|
||||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||||
|
@ -959,7 +1002,7 @@ public class SDLoss extends SDOps {
|
||||||
/**
|
/**
|
||||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||||
* otherwise, the output is a scalar.<br>
|
* otherwise, the output is a scalar.<br>
|
||||||
* <p><br>
|
* <p><br>
|
||||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||||
|
|
||||||
import java.lang.String;
|
import java.lang.String;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||||
|
|
||||||
import java.lang.String;
|
import java.lang.String;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||||
|
|
||||||
import java.lang.String;
|
import java.lang.String;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -144,22 +144,22 @@ public class SDRNN extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||||
* SUPPORTS following data formats:\n<br>
|
* SUPPORTS following data formats:<br>
|
||||||
* for unidirectional: \n" +<br>
|
* for unidirectional:<br>
|
||||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||||
* for bidirectional:\n<br>
|
* for bidirectional:<br>
|
||||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||||
* SUPPORTS following direction modes:\n<br>
|
* SUPPORTS following direction modes:<br>
|
||||||
* FWD: forward<br>
|
* FWD: forward<br>
|
||||||
* BWD: backward<br>
|
* BWD: backward<br>
|
||||||
* BIDIR_SUM: bidirectional sum\n<br>
|
* BIDIR_SUM: bidirectional sum<br>
|
||||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
* BIDIR_CONCAT: bidirectional concat<br>
|
||||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||||
* You may use different gate configurations:<br>
|
* You may use different gate configurations:<br>
|
||||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||||
*
|
*
|
||||||
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||||
|
@ -180,22 +180,22 @@ public class SDRNN extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||||
* SUPPORTS following data formats:\n<br>
|
* SUPPORTS following data formats:<br>
|
||||||
* for unidirectional: \n" +<br>
|
* for unidirectional:<br>
|
||||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||||
* for bidirectional:\n<br>
|
* for bidirectional:<br>
|
||||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||||
* SUPPORTS following direction modes:\n<br>
|
* SUPPORTS following direction modes:<br>
|
||||||
* FWD: forward<br>
|
* FWD: forward<br>
|
||||||
* BWD: backward<br>
|
* BWD: backward<br>
|
||||||
* BIDIR_SUM: bidirectional sum\n<br>
|
* BIDIR_SUM: bidirectional sum<br>
|
||||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
* BIDIR_CONCAT: bidirectional concat<br>
|
||||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||||
* You may use different gate configurations:<br>
|
* You may use different gate configurations:<br>
|
||||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||||
*
|
*
|
||||||
* @param names names May be null. Arrays of names for the output variables.
|
* @param names names May be null. Arrays of names for the output variables.
|
||||||
|
@ -218,22 +218,22 @@ public class SDRNN extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||||
* SUPPORTS following data formats:\n<br>
|
* SUPPORTS following data formats:<br>
|
||||||
* for unidirectional: \n" +<br>
|
* for unidirectional:<br>
|
||||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||||
* for bidirectional:\n<br>
|
* for bidirectional:<br>
|
||||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||||
* SUPPORTS following direction modes:\n<br>
|
* SUPPORTS following direction modes:<br>
|
||||||
* FWD: forward<br>
|
* FWD: forward<br>
|
||||||
* BWD: backward<br>
|
* BWD: backward<br>
|
||||||
* BIDIR_SUM: bidirectional sum\n<br>
|
* BIDIR_SUM: bidirectional sum<br>
|
||||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
* BIDIR_CONCAT: bidirectional concat<br>
|
||||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||||
* You may use different gate configurations:<br>
|
* You may use different gate configurations:<br>
|
||||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||||
*
|
*
|
||||||
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||||
|
@ -248,22 +248,22 @@ public class SDRNN extends SDOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||||
* SUPPORTS following data formats:\n<br>
|
* SUPPORTS following data formats:<br>
|
||||||
* for unidirectional: \n" +<br>
|
* for unidirectional:<br>
|
||||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||||
* for bidirectional:\n<br>
|
* for bidirectional:<br>
|
||||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||||
* SUPPORTS following direction modes:\n<br>
|
* SUPPORTS following direction modes:<br>
|
||||||
* FWD: forward<br>
|
* FWD: forward<br>
|
||||||
* BWD: backward<br>
|
* BWD: backward<br>
|
||||||
* BIDIR_SUM: bidirectional sum\n<br>
|
* BIDIR_SUM: bidirectional sum<br>
|
||||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
* BIDIR_CONCAT: bidirectional concat<br>
|
||||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||||
* You may use different gate configurations:<br>
|
* You may use different gate configurations:<br>
|
||||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||||
*
|
*
|
||||||
* @param names names May be null. Arrays of names for the output variables.
|
* @param names names May be null. Arrays of names for the output variables.
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||||
|
|
||||||
import java.lang.String;
|
import java.lang.String;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.enums;
|
package org.nd4j.enums;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Activations */
|
||||||
public enum CellAct {
|
public enum CellAct {
|
||||||
TANH,
|
TANH,
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.enums;
|
package org.nd4j.enums;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Data format: "NCHW" or "NHWC" */
|
||||||
public enum DataFormat {
|
public enum DataFormat {
|
||||||
NCHW,
|
NCHW,
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.enums;
|
package org.nd4j.enums;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Activations */
|
||||||
public enum GateAct {
|
public enum GateAct {
|
||||||
TANH,
|
TANH,
|
||||||
|
|
||||||
|
|
|
@ -1,32 +1,43 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.enums;
|
package org.nd4j.enums;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
|
||||||
|
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
|
||||||
|
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
|
||||||
|
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
|
||||||
|
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
|
||||||
|
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
|
||||||
|
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. */
|
||||||
public enum ImageResizeMethod {
|
public enum ImageResizeMethod {
|
||||||
ResizeBilinear, // as java require
|
ResizeBilinear,
|
||||||
ResizeNearest,
|
|
||||||
ResizeBicubic,
|
ResizeBicubic,
|
||||||
ResizeArea,
|
|
||||||
|
ResizeNearest,
|
||||||
|
|
||||||
ResizeGaussian,
|
ResizeGaussian,
|
||||||
ResizeLanczos3,
|
|
||||||
ResizeLanczos5,
|
ResizeLanczos5,
|
||||||
ResizeMitchellcubic;
|
|
||||||
|
ResizeMitchelcubic,
|
||||||
|
|
||||||
|
ResizeArea
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,25 +1,28 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.enums;
|
package org.nd4j.enums;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* for unidirectional: TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
|
||||||
|
* NST: shape [numExamples, inOutSize, timeLength]<br>
|
||||||
|
* NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout<br> for bidirectional:
|
||||||
|
* T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) */
|
||||||
public enum LSTMDataFormat {
|
public enum LSTMDataFormat {
|
||||||
TNS,
|
TNS,
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,30 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.enums;
|
package org.nd4j.enums;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* direction <br>
|
||||||
|
* FWD: 0 = fwd
|
||||||
|
* BWD: 1 = bwd
|
||||||
|
* BIDIR_SUM: 2 = bidirectional sum
|
||||||
|
* BIDIR_CONCAT: 3 = bidirectional concat
|
||||||
|
* BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */
|
||||||
public enum LSTMDirectionMode {
|
public enum LSTMDirectionMode {
|
||||||
FWD,
|
FWD,
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.enums;
|
package org.nd4j.enums;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Activations */
|
||||||
public enum OutAct {
|
public enum OutAct {
|
||||||
TANH,
|
TANH,
|
||||||
|
|
||||||
|
|
|
@ -1,22 +1,20 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.enums;
|
package org.nd4j.enums;
|
||||||
|
|
||||||
|
|
|
@ -1,22 +1,20 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.enums;
|
package org.nd4j.enums;
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,28 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.enums;
|
package org.nd4j.enums;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The data format of the input. Input shape depends on data format (in config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, inSize]<br>
|
||||||
|
* NST -> [batchSize, inSize, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, inSize]<br> */
|
||||||
public enum RnnDataFormat {
|
public enum RnnDataFormat {
|
||||||
TNS,
|
TNS,
|
||||||
|
|
||||||
|
|
|
@ -20,29 +20,20 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.loss;
|
package org.nd4j.linalg.api.ops.impl.loss;
|
||||||
|
|
||||||
import org.nd4j.autodiff.loss.LossReduce;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp;
|
import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class CtcLoss extends BaseLoss {
|
public class CtcLoss extends DynamicCustomOp {
|
||||||
|
|
||||||
|
|
||||||
public CtcLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
|
public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){
|
||||||
super(sameDiff, lossReduce, predictions, weights, labels);
|
super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths});
|
||||||
}
|
}
|
||||||
|
|
||||||
public CtcLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights,
|
|
||||||
LossReduce lossReduce) {
|
|
||||||
this(sameDiff, lossReduce, predictions, weights, label);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CtcLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
|
|
||||||
super(lossReduce, predictions, weights, labels);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CtcLoss(){ }
|
public CtcLoss(){ }
|
||||||
|
|
||||||
|
@ -55,6 +46,6 @@ public class CtcLoss extends BaseLoss {
|
||||||
public List<SDVariable> doDiff(List<SDVariable> grad) {
|
public List<SDVariable> doDiff(List<SDVariable> grad) {
|
||||||
//No external gradient
|
//No external gradient
|
||||||
//Args are: predictions, weights, label
|
//Args are: predictions, weights, label
|
||||||
return new AbsoluteDifferenceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs();
|
return new CtcLossBp(sameDiff, arg(0), arg(1), arg(2),arg(3)).outputs();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,17 +20,17 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.loss.bp;
|
package org.nd4j.linalg.api.ops.impl.loss.bp;
|
||||||
|
|
||||||
import org.nd4j.autodiff.loss.LossReduce;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class CtcLossBp extends BaseLossBp {
|
public class CtcLossBp extends DynamicCustomOp {
|
||||||
|
|
||||||
|
|
||||||
public CtcLossBp(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
|
public CtcLossBp(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){
|
||||||
super(sameDiff, lossReduce, predictions, weights, labels);
|
super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths});
|
||||||
}
|
}
|
||||||
|
|
||||||
public CtcLossBp(){ }
|
public CtcLossBp(){ }
|
||||||
|
|
|
@ -1,22 +1,20 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory.ops;
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
|
|
@ -1,22 +1,20 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory.ops;
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
@ -134,7 +132,7 @@ public class NDBitwise {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
||||||
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
* Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||||
*
|
*
|
||||||
* @param x Input to be bit shifted (INT type)
|
* @param x Input to be bit shifted (INT type)
|
||||||
|
@ -180,7 +178,7 @@ public class NDBitwise {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
||||||
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
* Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||||
*
|
*
|
||||||
* @param x Input to be bit shifted (INT type)
|
* @param x Input to be bit shifted (INT type)
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory.ops;
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.enums.DataFormat;
|
import org.nd4j.enums.DataFormat;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -41,8 +41,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - average pooling 2d<br>
|
* 2D Convolution layer operation - average pooling 2d<br>
|
||||||
*
|
*
|
||||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling2DConfig Configuration Object
|
* @param Pooling2DConfig Configuration Object
|
||||||
* @return output Result after applying average pooling on the input (NUMERIC type)
|
* @return output Result after applying average pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -54,9 +53,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* 3D convolution layer operation - average pooling 3d <br>
|
* 3D convolution layer operation - average pooling 3d <br>
|
||||||
*
|
*
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling3DConfig Configuration Object
|
* @param Pooling3DConfig Configuration Object
|
||||||
* @return output after applying average pooling on the input (NUMERIC type)
|
* @return output after applying average pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -161,9 +158,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* Convolution 3D operation with optional bias <br>
|
* Convolution 3D operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||||
* @param Conv3DConfig Configuration Object
|
* @param Conv3DConfig Configuration Object
|
||||||
|
@ -180,9 +175,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* Convolution 3D operation with optional bias <br>
|
* Convolution 3D operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||||
* @param Conv3DConfig Configuration Object
|
* @param Conv3DConfig Configuration Object
|
||||||
* @return output Conv3d output variable (NUMERIC type)
|
* @return output Conv3d output variable (NUMERIC type)
|
||||||
|
@ -196,8 +189,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* 2D deconvolution operation with optional bias<br>
|
* 2D deconvolution operation with optional bias<br>
|
||||||
*
|
*
|
||||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||||
* @param DeConv2DConfig Configuration Object
|
* @param DeConv2DConfig Configuration Object
|
||||||
|
@ -214,8 +206,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* 2D deconvolution operation with optional bias<br>
|
* 2D deconvolution operation with optional bias<br>
|
||||||
*
|
*
|
||||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||||
* @param DeConv2DConfig Configuration Object
|
* @param DeConv2DConfig Configuration Object
|
||||||
* @return output result of deconv2d op (NUMERIC type)
|
* @return output result of deconv2d op (NUMERIC type)
|
||||||
|
@ -263,8 +254,7 @@ public class NDCNN {
|
||||||
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
||||||
* = [mb, 2, 4, 4]<br>
|
* = [mb, 2, 4, 4]<br>
|
||||||
*
|
*
|
||||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param blockSize Block size, in the height/width dimension
|
* @param blockSize Block size, in the height/width dimension
|
||||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
|
@ -373,8 +363,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
|
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
|
||||||
*
|
*
|
||||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling2DConfig Configuration Object
|
* @param Pooling2DConfig Configuration Object
|
||||||
*/
|
*/
|
||||||
public INDArray[] maxPoolWithArgmax(INDArray input, Pooling2DConfig Pooling2DConfig) {
|
public INDArray[] maxPoolWithArgmax(INDArray input, Pooling2DConfig Pooling2DConfig) {
|
||||||
|
@ -385,8 +374,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - max pooling 2d <br>
|
* 2D Convolution layer operation - max pooling 2d <br>
|
||||||
*
|
*
|
||||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling2DConfig Configuration Object
|
* @param Pooling2DConfig Configuration Object
|
||||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -398,9 +386,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* 3D convolution layer operation - max pooling 3d operation.<br>
|
* 3D convolution layer operation - max pooling 3d operation.<br>
|
||||||
*
|
*
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param Pooling3DConfig Configuration Object
|
* @param Pooling3DConfig Configuration Object
|
||||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -412,8 +398,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* Separable 2D convolution operation with optional bias <br>
|
* Separable 2D convolution operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||||
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
|
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
|
||||||
|
@ -432,8 +417,7 @@ public class NDCNN {
|
||||||
/**
|
/**
|
||||||
* Separable 2D convolution operation with optional bias <br>
|
* Separable 2D convolution operation with optional bias <br>
|
||||||
*
|
*
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||||
* @param Conv2DConfig Configuration Object
|
* @param Conv2DConfig Configuration Object
|
||||||
|
@ -471,8 +455,7 @@ public class NDCNN {
|
||||||
* Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
* Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
||||||
* = [mb, 2, 4, 4] <br>
|
* = [mb, 2, 4, 4] <br>
|
||||||
*
|
*
|
||||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
|
||||||
* @param blockSize Block size, in the height/width dimension
|
* @param blockSize Block size, in the height/width dimension
|
||||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory.ops;
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.enums.ImageResizeMethod;
|
import org.nd4j.enums.ImageResizeMethod;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory.ops;
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.NDValidation;
|
import org.nd4j.linalg.factory.NDValidation;
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory.ops;
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
import org.nd4j.autodiff.loss.LossReduce;
|
import org.nd4j.autodiff.loss.LossReduce;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.NDValidation;
|
import org.nd4j.linalg.factory.NDValidation;
|
||||||
|
@ -35,7 +35,7 @@ public class NDLoss {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output loss variable (NUMERIC type)
|
* @return output loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights,
|
public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights,
|
||||||
|
@ -71,7 +71,7 @@ public class NDLoss {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param dimension Dimension to perform the cosine distance over
|
* @param dimension Dimension to perform the cosine distance over
|
||||||
* @return output Cosine distance loss (NUMERIC type)
|
* @return output Cosine distance loss (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -104,6 +104,25 @@ public class NDLoss {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CTC Loss: Connectionist Temporal Classification Loss. See:<br>
|
||||||
|
* https://dl.acm.org/citation.cfm?id=1143891<br>
|
||||||
|
*
|
||||||
|
* @param targetLabels Label array (NUMERIC type)
|
||||||
|
* @param logitInput Inputs (NUMERIC type)
|
||||||
|
* @param targetLabelLengths Length of the target label (NUMERIC type)
|
||||||
|
* @param logitInputLengths Length of the input (NUMERIC type)
|
||||||
|
* @return output Ctc loss (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray ctcLoss(INDArray targetLabels, INDArray logitInput, INDArray targetLabelLengths,
|
||||||
|
INDArray logitInputLengths) {
|
||||||
|
NDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels);
|
||||||
|
NDValidation.validateNumerical("ctcLoss", "logitInput", logitInput);
|
||||||
|
NDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths);
|
||||||
|
NDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(targetLabels, logitInput, targetLabelLengths, logitInputLengths))[0];
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Hinge loss: a loss function used for training classifiers.<br>
|
* Hinge loss: a loss function used for training classifiers.<br>
|
||||||
* Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
|
* Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
|
||||||
|
@ -112,7 +131,7 @@ public class NDLoss {
|
||||||
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
|
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights,
|
public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights,
|
||||||
|
@ -152,7 +171,7 @@ public class NDLoss {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param delta Loss function delta value
|
* @param delta Loss function delta value
|
||||||
* @return output Huber loss (NUMERIC type)
|
* @return output Huber loss (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -204,7 +223,7 @@ public class NDLoss {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param epsilon epsilon
|
* @param epsilon epsilon
|
||||||
* @return output Log loss (NUMERIC type)
|
* @return output Log loss (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -237,7 +256,7 @@ public class NDLoss {
|
||||||
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
|
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
|
||||||
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
|
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
|
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -275,7 +294,7 @@ public class NDLoss {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output Loss variable, scalar output (NUMERIC type)
|
* @return output Loss variable, scalar output (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights,
|
public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights,
|
||||||
|
@ -306,13 +325,13 @@ public class NDLoss {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||||
* this is the mean squared error loss function.<br>
|
* this is the mean squared error loss function.<br>
|
||||||
*
|
*
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights,
|
public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights,
|
||||||
|
@ -325,7 +344,7 @@ public class NDLoss {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||||
* this is the mean squared error loss function.<br>
|
* this is the mean squared error loss function.<br>
|
||||||
*
|
*
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
|
@ -357,7 +376,7 @@ public class NDLoss {
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictionLogits Predictions array (NUMERIC type)
|
* @param predictionLogits Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -398,7 +417,7 @@ public class NDLoss {
|
||||||
/**
|
/**
|
||||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||||
* otherwise, the output is a scalar.<br>
|
* otherwise, the output is a scalar.<br>
|
||||||
* <p><br>
|
* <p><br>
|
||||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||||
|
@ -410,7 +429,7 @@ public class NDLoss {
|
||||||
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
|
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
|
||||||
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
|
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||||
* @return output Loss variable (NUMERIC type)
|
* @return output Loss variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
|
@ -425,7 +444,7 @@ public class NDLoss {
|
||||||
/**
|
/**
|
||||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||||
* otherwise, the output is a scalar.<br>
|
* otherwise, the output is a scalar.<br>
|
||||||
* <p><br>
|
* <p><br>
|
||||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory.ops;
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.enums.PartitionMode;
|
import org.nd4j.enums.PartitionMode;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory.ops;
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.enums.PadMode;
|
import org.nd4j.enums.PadMode;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory.ops;
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
|
||||||
|
@ -85,22 +85,22 @@ public class NDRNN {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||||
* SUPPORTS following data formats:\n<br>
|
* SUPPORTS following data formats:<br>
|
||||||
* for unidirectional: \n" +<br>
|
* for unidirectional:<br>
|
||||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||||
* for bidirectional:\n<br>
|
* for bidirectional:<br>
|
||||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||||
* SUPPORTS following direction modes:\n<br>
|
* SUPPORTS following direction modes:<br>
|
||||||
* FWD: forward<br>
|
* FWD: forward<br>
|
||||||
* BWD: backward<br>
|
* BWD: backward<br>
|
||||||
* BIDIR_SUM: bidirectional sum\n<br>
|
* BIDIR_SUM: bidirectional sum<br>
|
||||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
* BIDIR_CONCAT: bidirectional concat<br>
|
||||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||||
* You may use different gate configurations:<br>
|
* You may use different gate configurations:<br>
|
||||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||||
*
|
*
|
||||||
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||||
|
@ -121,22 +121,22 @@ public class NDRNN {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||||
* SUPPORTS following data formats:\n<br>
|
* SUPPORTS following data formats:<br>
|
||||||
* for unidirectional: \n" +<br>
|
* for unidirectional:<br>
|
||||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||||
* for bidirectional:\n<br>
|
* for bidirectional:<br>
|
||||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||||
* SUPPORTS following direction modes:\n<br>
|
* SUPPORTS following direction modes:<br>
|
||||||
* FWD: forward<br>
|
* FWD: forward<br>
|
||||||
* BWD: backward<br>
|
* BWD: backward<br>
|
||||||
* BIDIR_SUM: bidirectional sum\n<br>
|
* BIDIR_SUM: bidirectional sum<br>
|
||||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
* BIDIR_CONCAT: bidirectional concat<br>
|
||||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||||
* You may use different gate configurations:<br>
|
* You may use different gate configurations:<br>
|
||||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||||
*
|
*
|
||||||
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
/*
|
/*******************************************************************************
|
||||||
* ******************************************************************************
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
* *
|
*
|
||||||
* *
|
* This program and the accompanying materials are made available under the
|
||||||
* * This program and the accompanying materials are made available under the
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
*
|
||||||
* *
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* * See the NOTICE file distributed with this work for additional
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* * information regarding copyright ownership.
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
* License for the specific language governing permissions and limitations
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* under the License.
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
*
|
||||||
* * License for the specific language governing permissions and limitations
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
* * under the License.
|
******************************************************************************/
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory.ops;
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
Loading…
Reference in New Issue