forked from t13m/kaldi-readers-for-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathshape-funcs.cc
27 lines (23 loc) · 918 Bytes
/
shape-funcs.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include "shape-funcs.hh"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/framework/common_shape_fns.h"
namespace shape_util {

using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;

// Shape function for ops whose inputs and outputs are all scalars.
//
// Inputs are checked in index order to have rank 0. The first failing check's
// error (from InferenceContext::WithRank) is returned immediately, in which
// case no outputs are set. When every input passes, each output is declared to
// be a scalar and OK is returned.
tensorflow::Status ScalarInputsAndOutputs(InferenceContext *c) {
  const int input_count = c->num_inputs();
  const int output_count = c->num_outputs();

  for (int idx = 0; idx < input_count; ++idx) {
    // WithRank requires an out-parameter; only its status matters here.
    ShapeHandle scalar_shape;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(idx), 0, &scalar_shape));
  }

  for (int idx = 0; idx < output_count; ++idx) {
    c->set_output(idx, c->Scalar());
  }
  return tensorflow::Status::OK();
}

// Shape function that declares output 0 to be a vector of length 2.
// Inputs are not inspected, and no output other than index 0 is set.
tensorflow::Status TwoElementOutput(InferenceContext *c) {
  const ShapeHandle pair_shape = c->Vector(2);
  c->set_output(0, pair_shape);
  return tensorflow::Status::OK();
}

}  // namespace shape_util