[Kernel C API] Implementation of variable ops RFC. #49717
Conversation
tensorflow/c/kernels.cc
Outdated
@@ -551,3 +554,257 @@ TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
}
return tf_tensor;
}

tensorflow::Status EnsureSparseVariableAccess(TF_OpKernelContext* ctx,
I am moving these to the c_api_experimental file (I ran into some build issues). I wanted to get the PR started early to gather feedback.
tensorflow/c/kernels.cc
Outdated
var->tensor()->shape(), &tmp, attr));
tensorflow::Status s;
TF_Tensor *tf_tmp = TF_TensorFromTensor(tmp, &s);
TF_Tensor *tf_tensor = TF_TensorFromTensor(*var->tensor(), &s);
I think TF_DeleteTensor() is needed at the end of the function: TF_TensorFromTensor news a TF_Tensor struct, which will leak if TF_DeleteTensor() is never invoked.
// Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
*status = tensorflow::Status::OK();
if (!src.IsInitialized()) {
*status = FailedPrecondition(
"attempt to use a tensor with an uninitialized value");
return nullptr;
}
if (src.NumElements() == 0) {
return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
}
if (src.dtype() == tensorflow::DT_RESOURCE) {
if (src.shape().dims() != 0) {
*status = InvalidArgument(
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
src.shape().DebugString(),
"). Please file a bug at "
"https://github.com/tensorflow/tensorflow/issues/new, "
"ideally with a "
"short code snippet that reproduces this error.");
return nullptr;
}
const string str =
src.scalar<tensorflow::ResourceHandle>()().SerializeAsString();
TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
std::memcpy(TF_TensorData(t), str.c_str(), str.size());
return t;
}
Tensor tensor;
if (!tensor.CopyFrom(src, src.shape())) {
return nullptr;
}
return new TF_Tensor{new tensorflow::TensorInterface(std::move(tensor))};
}
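A minimal sketch of the ownership contract this comment points at (the helper name and surrounding calls are assumptions based on the diff excerpt, not the PR's actual code): every successful TF_TensorFromTensor call heap-allocates a TF_Tensor wrapper that must eventually be paired with TF_DeleteTensor.

```cpp
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/tensor.h"

// Hypothetical fragment: wrap a tensorflow::Tensor for the C API and make
// sure the wrapper is released on every path. TF_TensorFromTensor is the
// internal helper quoted above and is assumed to be visible in this file.
void WrapUseAndRelease(const tensorflow::Tensor& src) {
  tensorflow::Status s;
  TF_Tensor* wrapper = TF_TensorFromTensor(src, &s);
  if (!s.ok() || wrapper == nullptr) return;
  // ... hand `wrapper` to the copy functor or other C-API code ...
  TF_DeleteTensor(wrapper);  // releases the wrapper; the Tensor buffer itself is ref-counted
}
```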
Done.
tensorflow/c/kernels.cc
Outdated
context->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
tensorflow::Status s;
TF_Tensor *tf_tmp = TF_TensorFromTensor(tmp, &s);
TF_Tensor *tf_tensor = TF_TensorFromTensor(*tensor, &s);
Do we need to check the two statuses?
Maybe we can return the status back to the caller of PrepareUpdateVariable().
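A sketch of what that could look like inside the helper (this is a fragment assumed from the diff context, not the PR's actual code): check each conversion status and return it to the caller instead of discarding it.

```cpp
// Hypothetical fragment of the helper body: surface the TF_TensorFromTensor
// status to the caller and clean up on the error path.
tensorflow::Status s;
TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
if (!s.ok()) return s;
TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
if (!s.ok()) {
  TF_DeleteTensor(tf_tmp);  // avoid leaking the first wrapper on the error path
  return s;
}
copyFunc(ctx, tf_tensor, tf_tmp);
```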
The implementation for the Variable Ops RFC. https://github.com/tensorflow/community/blob/master/rfcs/20210504-kernel-extension-variable-ops.md
fc79d57 to cc6a3f1
tensorflow/c/kernels.h
Outdated
// &total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorShape(
TF_OpKernelConstruction* ctx,
const char* attr_name, int64_t* values,
nit: Rename values to dims and max_vals to num_dims to be consistent with TF_GraphGetTensorShape.
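For reference, the renamed declaration could look roughly like this; the exact trailing parameters (the count type and whether a TF_Status* is taken) are assumptions, the point is only the dims/num_dims naming borrowed from TF_GraphGetTensorShape.

```cpp
// Sketch of the renamed prototype (assumed shape; see TF_GraphGetTensorShape
// in tensorflow/c/c_api.h for the naming convention being matched).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorShape(
    TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* dims,
    size_t num_dims, TF_Status* status);
```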
Done.
tensorflow/c/kernels.cc
Outdated
@@ -551,3 +571,301 @@ TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
}
return tf_tensor;
}

tensorflow::Status EnsureSparseVariableAccess(TF_OpKernelContext* ctx,
Is it possible to share code with the existing impl for EnsureSparseVariableAccess in training_op_helpers.h? Same comment for other helpers below.
Thanks @saxenasaurabh for the review. I have used the existing helper functions where possible, e.g. LookupResource. With EnsureSparseVariableAccess there is a Device dependency, which we are passing in as copy functors. I agree there is duplication of code that could be avoided. One possibility is to refactor the core helper functions to remove this dependency and adopt them here. Maybe we can do that as a follow-up cleanup, since it will require invasive changes in the core that could break things. What do you think?
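Purely as a sketch of the follow-up being discussed (the name EnsureSparseVariableAccessShared and the callback type are assumptions, not existing TensorFlow code): the Device-templated copy inside the core helper could be hidden behind a type-erased callback, so the core kernels and this C API layer could share one implementation.

```cpp
#include <functional>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_var.h"

// Hypothetical shared helper: instead of templating on Device, accept a copy
// callback. The C API layer would forward the plugin's copy functor; core
// kernels could pass a lambda wrapping their existing dense-update code.
using CopyTensorFn = std::function<void(
    tensorflow::OpKernelContext*, const tensorflow::Tensor& from,
    tensorflow::Tensor* to)>;

tensorflow::Status EnsureSparseVariableAccessShared(
    tensorflow::OpKernelContext* ctx, tensorflow::Var* var,
    const CopyTensorFn& copy);
```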
Sounds good. Please try to clean this up as a follow-up. That would give test coverage for free as well.
Makes sense, will do.
tensorflow/c/kernels.cc
Outdated
tf_tmp = TF_TensorFromTensor(tmp, &s);
tf_tensor = TF_TensorFromTensor(*var->tensor(), &s);
TF_Tensor *tf_tmp = TF_TensorFromTensor(tmp, &s);
TF_Tensor *tf_tensor = TF_TensorFromTensor(*var->tensor(), &s);
copyFunc(ctx, tf_tensor, tf_tmp);
We are expecting the plugin to call TF_DeleteTensor in copyFunc. This would be more in line with the rest of TF, where we release the tensors in Compute.
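A sketch of what a plugin-side copy functor honoring that contract might look like (the functor signature is taken from the diff context and a plain host memcpy stands in for the real device copy; this is an assumption, not the actual metal-plugin code):

```cpp
#include <cstring>
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_tensor.h"

// Hypothetical plugin-side copy functor: copy `source` into `dest`, then
// release both wrappers, since per the comment above the C API expects the
// functor to call TF_DeleteTensor.
void PluginCopyFunc(TF_OpKernelContext* ctx, TF_Tensor* source, TF_Tensor* dest) {
  (void)ctx;  // unused in this host-only sketch; a real plugin would enqueue a device copy
  std::memcpy(TF_TensorData(dest), TF_TensorData(source),
              TF_TensorByteSize(source));
  TF_DeleteTensor(source);
  TF_DeleteTensor(dest);
}
```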
tensorflow/c/kernels.h
Outdated
TF_OpKernelContext* ctx,
int input,
bool lock_held,
bool isVariantType,
Do we need to expose isVariantType in the API or can that be inferred from the tensor? I believe this is equivalent to TF_TensorType(tensor) == TF_VARIANT.
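The inference being suggested is roughly the following (wrapping it in a small helper is an assumption for illustration):

```cpp
#include "tensorflow/c/tf_tensor.h"

// Hypothetical: derive the variant-ness from the input tensor instead of
// taking a separate isVariantType flag in the C API.
bool IsVariantInput(const TF_Tensor* tensor) {
  return TF_TensorType(tensor) == TF_VARIANT;
}
```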
@saxenasaurabh thanks! (Sorry for the delay.) You are right, we can probably skip isVariantType in the API. However, our current metal-plugin release is using this API. If it's not too much of a concern, can we keep it this way?
@@ -113,6 +114,19 @@ struct TF_OperationDescription {
std::set<tensorflow::string> colocation_constraints;
};

struct TF_VariableInputLockHolder {
Could this be in c_api_experimental as well?
The implementation for the Variable Ops RFC.
https://github.com/tensorflow/community/blob/master/rfcs/20210504-kernel-extension-variable-ops.md
@penpornk, @reedwm, @saxenasaurabh, @jzhoulon