Multi-Tensor Input in Servo-Beam #10
base: master
Conversation
…mn option and tests
Integrate Arrow as internal processing container
tfx_bsl/beam/run_inference.py
Outdated
@@ -383,14 +504,20 @@ def setup(self):
    # user agent once custom header is supported in googleapiclient.
    self._api_client = discovery.build('ml', 'v1')

  def _extract_from_recordBatch(self, elements: pa.RecordBatch):
    serialized_examples = bsl_util.ExtractSerializedExampleFromRecordBatch(elements)
This seems to be the same in the Batch and Remote DoFns. Maybe extract this out to the base class, and only get model_input in _extract_from_recordBatch?
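The refactor suggested here can be sketched as a template method: the base DoFn owns the shared extraction, and each subclass supplies only its own model_input construction. All class names, the `__raw_record__` column constant, and the list-based stand-in for a RecordBatch below are hypothetical, not the actual tfx_bsl code.

```python
# Illustrative sketch of moving shared extraction into a base class.
# Parallel name/column lists stand in for a pa.RecordBatch.
class _BaseDoFnSketch:
    def _extract_from_record_batch(self, names, columns):
        # Shared logic lives in the base: pull serialized examples out once.
        serialized = [v for name, col in zip(names, columns)
                      if name == "__raw_record__" for v in col]
        return serialized, self._build_model_input(serialized, names, columns)

    def _build_model_input(self, serialized, names, columns):
        # Subclasses decide how the model input mapping is built.
        raise NotImplementedError


class _BatchDoFnSketch(_BaseDoFnSketch):
    def _build_model_input(self, serialized, names, columns):
        # Batch-specific: map the single input tensor name to the raw bytes.
        return {"input_tensor": serialized}


names = ["__raw_record__"]
columns = [[b"ex1", b"ex2"]]
serialized, model_input = _BatchDoFnSketch()._extract_from_record_batch(names, columns)
```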
tfx_bsl/beam/run_inference.py
Outdated
) -> Mapping[Text, np.ndarray]:
  self._check_elements(elements)
  outputs = self._run_tf_operations(elements)
self, tensors: Mapping[Any, Any]) -> Mapping[Text, np.ndarray]:
Add a comment on what's expected in tensors. Also, is the Mapping key a Text?
self, elements: Mapping[Any, Any],
outputs: Mapping[Text, np.ndarray]
) -> Iterable[Tuple[Union[str, bytes], classification_pb2.Classifications]]:
serialized_examples, = elements.values()
Remove ','
It won't give the right answer without the comma — the trailing comma is what unpacks the single value out of elements.values().
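The point under discussion can be demonstrated in two lines: with the trailing comma, Python tuple-unpacks the single value from the dict view; without it, the variable is bound to the `dict_values` view itself. The dict contents below are illustrative.

```python
# Trailing-comma unpacking vs. plain assignment from a one-entry dict.
d = {"examples": [b"e1", b"e2"]}

serialized_examples, = d.values()  # unpacks the single value: [b"e1", b"e2"]
not_unpacked = d.values()          # a dict_values view, NOT the list inside it

# Unpacking also guards against the dict unexpectedly holding >1 entry:
# `a, = {...two keys...}.values()` raises ValueError instead of silently
# picking one.
```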
self, elements: Mapping[Any, Any],
outputs: Mapping[Text, np.ndarray]
) -> Iterable[Tuple[Union[str, bytes], classification_pb2.Classifications]]:
serialized_examples, = elements.values()
Are elements.values() the serialized examples?
raise ValueError('Expected to have one name and one alias per tensor')

include_request = True
if len(input_tensor_names) == 1:
Can we make the determination of a single input string tensor an internal utility function inside BaseDoFn?
The input tensor names are not in BaseDoFn.
include_request = True
if len(input_tensor_names) == 1:
  serialized_examples, = elements.values()
Shall we also check that the type of elements.values() is string/bytes?
It's checked in _extract_from_recordBatch.
tfx_bsl/beam/run_inference.py
Outdated
else:
  input_tensor_proto.tensor_shape.dim.add().size = len(elements[tensor_name][0])
Why is the dim size len(elements[tensor_name][0]) instead of:

for s in elements[tensor_name][0].shape:
  input_tensor_proto.tensor_shape.dim.add().size = s
We have an ndarray; I don't think we will have a shape parameter.
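For what it's worth, a NumPy ndarray does expose a `.shape` attribute, so the reviewer's loop is runnable; the two spellings agree for a rank-1 row, but only iterating `.shape` captures every dimension of a higher-rank input. The arrays below are illustrative stand-ins for one row of `elements[tensor_name]`.

```python
import numpy as np

# Compare len(...) against iterating .shape for rank-1 and rank-2 rows.
row_1d = np.array([0.1, 0.2, 0.3])
row_2d = np.array([[1, 2], [3, 4]])

dims_via_len = [len(row_1d)]         # only the first dimension
dims_via_shape = list(row_1d.shape)  # identical for rank 1
dims_2d = list(row_2d.shape)         # both dims; len() alone would miss one
```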
tfx_bsl/beam/run_inference.py
Outdated
for alias, tensor_name in zip(input_tensor_alias, input_tensor_names):
  input_tensor_proto = predict_log_tmpl.request.inputs[alias]
  input_tensor_proto.dtype = tf.as_dtype(input_tensor_types[alias]).as_datatype_enum
if len(input_tensor_alias) == 1:
Could the single input case be handled separately?
tfx_bsl/beam/run_inference.py
Outdated
alias = input_tensor_alias[0]
predict_log.request.inputs[alias].string_val.append(process_elements[i])
else:
  for alias, tensor_name in zip(input_tensor_alias, input_tensor_names):
Is this correct, given it's already inside the loop over alias, tensor_name?
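The restructuring both reviewers are circling can be sketched by branching on the number of aliases before any loop runs, so the single-input case never nests a loop inside a loop. A plain dict stands in for the PredictRequest proto inputs; all names are hypothetical.

```python
# Sketch: single string-tensor input handled separately from the per-alias loop.
def build_inputs(input_tensor_alias, elements, serialized_examples):
    inputs = {}
    if len(input_tensor_alias) == 1:
        # Single input: feed the serialized examples straight to the one alias.
        inputs[input_tensor_alias[0]] = list(serialized_examples)
    else:
        # Multi-tensor input: one entry per alias, pulled from the columns.
        for alias in input_tensor_alias:
            inputs[alias] = elements[alias]
    return inputs


single = build_inputs(["examples"], {}, [b"e1", b"e2"])
multi = build_inputs(["a", "b"], {"a": [1], "b": [2]}, None)
```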
) -> Iterable[Tuple[tf.train.Example, inference_pb2.MultiInferenceResponse]]:
self, elements: Mapping[Any, Any],
outputs: Mapping[Text, np.ndarray]
) -> Iterable[Tuple[Union[str, bytes], inference_pb2.MultiInferenceResponse]]:
Can this just be bytes instead of Union[str, bytes] ?
str is the same as 'bytes' in py2.
Just wanted to make sure it's compatible with py2
model_input = None
if (len(self._io_tensor_spec.input_tensor_names) == 1):
  model_input = {self._io_tensor_spec.input_tensor_names[0]: serialized_examples}
Can we just leave this in _BaseBatchSavedModelDoFn and move the rest to _BatchPredictDoFn?
tfx_bsl/public/beam/run_inference.py
Outdated
Args:
  examples: A PCollection containing examples.
  inference_spec_type: Model inference endpoint.
  Schema [optional]: required for models that requires
Mention this is only available for the Predict method.
tfx_bsl/beam/bsl_util.py
Outdated
_KERAS_INPUT_SUFFIX = '_input'

def ExtractSerializedExampleFromRecordBatch(elements: pa.RecordBatch) -> List[Text]:
ExtractSerializedExamplesFromRecordBatch
def ExtractSerializedExampleFromRecordBatch(elements: pa.RecordBatch) -> List[Text]:
  serialized_examples = None
  for column_name, column_array in zip(elements.schema.names, elements.columns):
    if column_name == _RECORDBATCH_COLUMN:
Should _RECORDBATCH_COLUMN be passed as an argument to the API?
If we use a constant here, it would mean users would have to use this same constant when creating the TFXIO.
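The reviewer's suggestion amounts to making the raw-record column name a parameter with the constant as its default, so callers who built their TFXIO with a different column name are not locked out. To stay dependency-free, the sketch below mimics a RecordBatch with parallel name/column lists; real code would walk `elements.schema.names` and `elements.columns`, and the `__raw_record__` value is a hypothetical placeholder.

```python
# Sketch: column name as an argument, defaulting to the module constant.
_RECORDBATCH_COLUMN = "__raw_record__"  # hypothetical default value

def extract_serialized_examples(names, columns, raw_column=_RECORDBATCH_COLUMN):
    # Scan the (name, column) pairs for the raw-record column.
    for name, column in zip(names, columns):
        if name == raw_column:
            return list(column)
    raise ValueError("raw record column %r not found" % raw_column)


examples = extract_serialized_examples(
    ["feature_a", "__raw_record__"], [[1, 2], [b"e1", b"e2"]])
```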
@beam.ptransform_fn
@beam.typehints.with_input_types(Union[tf.train.Example,
                                       tf.train.SequenceExample])
@beam.typehints.with_input_types(tf.train.Example)
@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog)
def RunInference(  # pylint: disable=invalid-name
Is the long-term plan to deprecate the tf.Example API and only have a RecordBatch API?
If so, mention it in a comment.
if prepare_instances_serialized:
  return [{'b64': base64.b64encode(value).decode()} for value in df[_RECORDBATCH_COLUMN]]
else:
  as_binary = df.columns.str.endswith("_bytes")
Why does the name end with "_bytes"?
User-specified byte columns; it's consistent with the original implementation.
This is required by Cloud AI Platform to indicate a bytes feature with the '_bytes' suffix.
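The two conventions discussed in this thread can be shown together: raw bytes go into the request as `{"b64": <base64 string>}` objects, and byte-valued columns are flagged by a `"_bytes"` name suffix. The row values and column names below are illustrative, not the PR's data.

```python
import base64

# Serialized examples become {"b64": ...} instances for the JSON request body.
rows = [b"ex1", b"ex2"]
instances = [{"b64": base64.b64encode(v).decode()} for v in rows]

# Byte columns are identified by the "_bytes" suffix convention.
column_names = ["age", "image_bytes"]
as_binary = [name.endswith("_bytes") for name in column_names]
```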
@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog)
def RunInferenceImpl(  # pylint: disable=invalid-name
def RunInferenceOnExamples(  # pylint: disable=invalid-name
Let's use the first option for the public API here: a polymorphic RunInference backed by RunInferenceImpl.
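The "polymorphic RunInference" idea can be roughed out as one public entry point that dispatches on the element type and delegates to per-type implementations. Plain functions stand in for Beam PTransforms here, and every name and the dispatch rule are hypothetical sketches of the design, not the actual tfx_bsl API.

```python
# Per-type implementations the public entry point delegates to.
def run_inference_on_examples(examples):
    # Placeholder: one prediction log per serialized example.
    return ["prediction_log"] * len(examples)

def run_inference_on_record_batch(rows):
    # Placeholder: one prediction log per record-batch row.
    return ["prediction_log"] * len(rows)

def run_inference(elements):
    # Polymorphic dispatch: bytes/str elements are treated as serialized
    # tf.Examples, anything else as RecordBatch-style rows.
    if elements and isinstance(elements[0], (bytes, str)):
        return run_inference_on_examples(elements)
    return run_inference_on_record_batch(elements)


logs = run_inference([b"e1", b"e2"])
```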
Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). 📝 Please visit https://cla.developers.google.com/ to sign. Once you've signed (or fixed any issues), please reply here.
CLAs look good, thanks!
Internally uses Arrow RecordBatch for processing and supports multi-tensor input.