From a6f9aa2601e55b9eda8ff6bf26bb0bfbd26875ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Ma=C5=9Blanka?= Date: Mon, 31 Oct 2022 20:39:27 +0100 Subject: [PATCH 1/8] Tensorflow pluggable device --- Cargo.toml | 6 +- tensorflow-sys/Cargo.toml | 2 + tensorflow-sys/README.md | 1 + tensorflow-sys/generate_bindgen_rs.sh | 5 + tensorflow-sys/src/experimental/c_api.rs | 113 +++++++++++++++++++++++ tensorflow-sys/src/experimental/mod.rs | 5 + tensorflow-sys/src/lib.rs | 5 + tensorflow-sys/tests/lib.rs | 20 ++++ 8 files changed, 156 insertions(+), 1 deletion(-) create mode 100644 tensorflow-sys/src/experimental/c_api.rs create mode 100644 tensorflow-sys/src/experimental/mod.rs diff --git a/Cargo.toml b/Cargo.toml index aa417279ba..0f9d5bbf31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,10 @@ [package] name = "tensorflow" version = "0.19.1" -authors = ["Adam Crume "] +authors = [ + "Adam Crume ", + "Maciej Maślanka ", +] description = "Rust language bindings for TensorFlow." license = "Apache-2.0" keywords = ["TensorFlow", "bindings"] @@ -44,6 +47,7 @@ tensorflow_gpu = ["tensorflow-sys/tensorflow_gpu"] tensorflow_unstable = [] tensorflow_runtime_linking = ["tensorflow-sys-runtime"] eager = ["tensorflow-sys/eager"] +experimental = ["tensorflow-sys/experimental"] # This is for testing purposes; users should not use this. examples_system_alloc = ["tensorflow-sys/examples_system_alloc"] private-docs-rs = ["tensorflow-sys/private-docs-rs"] # DO NOT RELY ON THIS diff --git a/tensorflow-sys/Cargo.toml b/tensorflow-sys/Cargo.toml index 061e3852dc..cec8e8d126 100644 --- a/tensorflow-sys/Cargo.toml +++ b/tensorflow-sys/Cargo.toml @@ -5,6 +5,7 @@ license = "Apache-2.0" authors = [ "Adam Crume ", "Ivan Ukhov ", + "Maciej Maślanka ", ] description = "The package provides bindings to TensorFlow." documentation = "https://tensorflow.github.io/rust" @@ -36,6 +37,7 @@ zip = "0.6.2" [features] tensorflow_gpu = [] eager = [] +experimental = [] # This is for testing purposes; users should not use this. examples_system_alloc = [] private-docs-rs = [] # DO NOT RELY ON THIS diff --git a/tensorflow-sys/README.md b/tensorflow-sys/README.md index 8b28a6ca1d..0b1626b9b7 100644 --- a/tensorflow-sys/README.md +++ b/tensorflow-sys/README.md @@ -59,6 +59,7 @@ compiled library will be picked up. **macOS Note**: Via [Homebrew](https://brew.sh/), you can just run `brew install libtensorflow`. +[tensorflow metal plugin]: https://developer.apple.com/metal/tensorflow-plugin/ ## Resources diff --git a/tensorflow-sys/generate_bindgen_rs.sh b/tensorflow-sys/generate_bindgen_rs.sh index 6ce66390ee..1c9883bf98 100755 --- a/tensorflow-sys/generate_bindgen_rs.sh +++ b/tensorflow-sys/generate_bindgen_rs.sh @@ -18,3 +18,8 @@ bindgen_options_eager="--allowlist-function TFE_.+ --allowlist-type TFE_.+ --all cmd="bindgen ${bindgen_options_eager} ${include_dir}/tensorflow/c/eager/c_api.h --output src/eager/c_api.rs -- -I ${include_dir}" echo ${cmd} ${cmd} + +bindgen_options_experimental="--no-derive-copy --allowlist-function TF_LoadPluggableDeviceLibrary --allowlist-function TF_DeletePluggableDeviceLibraryHandle --allowlist-var TF_Buffer* --allowlist-type TF_ShapeAndTypeList --allowlist-type TF_ShapeAndType --allowlist-type TF_CheckpointReader --allowlist-type TF_AttrBuilder --size_t-is-usize --default-enum-style=rust --generate-inline-functions --blocklist-type TF_Library --blocklist-type TF_DataType --blocklist-type TF_Status" +cmd="bindgen ${bindgen_options_experimental} ${include_dir}/tensorflow/c/c_api_experimental.h --output src/experimental/c_api.rs -- -I ${include_dir}" +echo ${cmd} +${cmd} diff --git a/tensorflow-sys/src/experimental/c_api.rs b/tensorflow-sys/src/experimental/c_api.rs new file mode 100644 index 0000000000..44449340a2 --- /dev/null +++ b/tensorflow-sys/src/experimental/c_api.rs @@ -0,0 +1,113 @@ +/* automatically generated by rust-bindgen 0.61.0 */ + +#[repr(C)] +#[derive(Debug)] +pub struct TF_CheckpointReader { + _unused: [u8; 0], +} +#[repr(C)] +#[derive(Debug)] +pub struct TF_AttrBuilder { + _unused: [u8; 0], +} +#[repr(C)] +pub struct TF_ShapeAndType { + pub num_dims: ::std::os::raw::c_int, + pub dims: *mut i64, + pub dtype: TF_DataType, +} +#[test] +fn bindgen_test_layout_TF_ShapeAndType() { + const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); + let ptr = UNINIT.as_ptr(); + assert_eq!( + ::std::mem::size_of::(), + 24usize, + concat!("Size of: ", stringify!(TF_ShapeAndType)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(TF_ShapeAndType)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).num_dims) as usize - ptr as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(TF_ShapeAndType), + "::", + stringify!(num_dims) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).dims) as usize - ptr as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(TF_ShapeAndType), + "::", + stringify!(dims) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).dtype) as usize - ptr as usize }, + 16usize, + concat!( + "Offset of field: ", + stringify!(TF_ShapeAndType), + "::", + stringify!(dtype) + ) + ); +} +#[repr(C)] +#[derive(Debug)] +pub struct TF_ShapeAndTypeList { + pub num_items: ::std::os::raw::c_int, + pub items: *mut TF_ShapeAndType, +} +#[test] +fn bindgen_test_layout_TF_ShapeAndTypeList() { + const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); + let ptr = UNINIT.as_ptr(); + assert_eq!( + ::std::mem::size_of::(), + 16usize, + concat!("Size of: ", stringify!(TF_ShapeAndTypeList)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(TF_ShapeAndTypeList)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).num_items) as usize - ptr as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(TF_ShapeAndTypeList), + "::", + stringify!(num_items) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).items) as usize - ptr as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(TF_ShapeAndTypeList), + "::", + stringify!(items) + ) + ); +} +extern "C" { + pub fn TF_LoadPluggableDeviceLibrary( + library_filename: *const ::std::os::raw::c_char, + status: *mut TF_Status, + ) -> *mut TF_Library; +} +extern "C" { + pub fn TF_DeletePluggableDeviceLibraryHandle(lib_handle: *mut TF_Library); +} diff --git a/tensorflow-sys/src/experimental/mod.rs b/tensorflow-sys/src/experimental/mod.rs new file mode 100644 index 0000000000..c036b73c19 --- /dev/null +++ b/tensorflow-sys/src/experimental/mod.rs @@ -0,0 +1,5 @@ +use crate::{ + TF_Library, TF_Status, TF_DataType, +}; + +include!("c_api.rs"); diff --git a/tensorflow-sys/src/lib.rs b/tensorflow-sys/src/lib.rs index 2e3f1999d7..87204b27db 100644 --- a/tensorflow-sys/src/lib.rs +++ b/tensorflow-sys/src/lib.rs @@ -11,3 +11,8 @@ include!("c_api.rs"); pub use crate::TF_AttrType::*; pub use crate::TF_Code::*; pub use crate::TF_DataType::*; + +#[cfg(feature = "experimental")] +mod experimental; +#[cfg(feature = "experimental")] +pub use experimental::*; diff --git a/tensorflow-sys/tests/lib.rs b/tensorflow-sys/tests/lib.rs index 946b9844db..b64471a860 100644 --- a/tensorflow-sys/tests/lib.rs +++ b/tensorflow-sys/tests/lib.rs @@ -46,3 +46,23 @@ fn tfe_tensor_handle() { ffi::TF_DeleteTensor(tf_tensor); } } + +/// Test that the experimental API works. +#[cfg(feature = "experimental")] +#[test] +fn load_plugable_device() { + let c_filename = std::ffi::CString::new("libmetal_plugin.dylib").expect("CString::new failed"); + unsafe { + let raw_status = ffi::TF_NewStatus(); + ffi::TF_LoadPluggableDeviceLibrary(c_filename.as_ptr(), raw_status); + if ffi::TF_GetCode(raw_status) != ffi::TF_OK { + panic!( + "{}", + std::ffi::CStr::from_ptr(ffi::TF_Message(raw_status)) + .to_string_lossy() + .into_owned() + ); + } + ffi::TF_DeleteStatus(raw_status); + }; +} From 8becc29a1474ccb90addf94ecef00bdb7989902c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Ma=C5=9Blanka?= Date: Tue, 1 Nov 2022 18:04:32 +0100 Subject: [PATCH 2/8] delete lib handle --- tensorflow-sys/tests/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow-sys/tests/lib.rs b/tensorflow-sys/tests/lib.rs index b64471a860..a9bbebc523 100644 --- a/tensorflow-sys/tests/lib.rs +++ b/tensorflow-sys/tests/lib.rs @@ -54,7 +54,7 @@ fn load_plugable_device() { let c_filename = std::ffi::CString::new("libmetal_plugin.dylib").expect("CString::new failed"); unsafe { let raw_status = ffi::TF_NewStatus(); - ffi::TF_LoadPluggableDeviceLibrary(c_filename.as_ptr(), raw_status); + let lib_handle = ffi::TF_LoadPluggableDeviceLibrary(c_filename.as_ptr(), raw_status); if ffi::TF_GetCode(raw_status) != ffi::TF_OK { panic!( "{}", @@ -63,6 +63,7 @@ fn load_plugable_device() { .into_owned() ); } + ffi::TF_DeletePluggableDeviceLibraryHandle(lib_handle); ffi::TF_DeleteStatus(raw_status); }; } From 481a65ade495b77d304d3bb53382df1db7220cfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Ma=C5=9Blanka?= Date: Fri, 7 Apr 2023 21:17:25 +0200 Subject: [PATCH 3/8] install dependencies for macOS --- .github/workflows/ci.yml | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2d1808b393..04a4b21535 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,6 +32,30 @@ jobs: uses: actions/setup-python@v4 with: python-version: 3.8 + - name: Check macOS architecture + if: matrix.os == 'macos-latest' + id: check-arch + run: | + arch_name="$(uname -m)" + echo "Detected architecture: $arch_name" + echo "::set-output name=architecture::$arch_name" + - name: Setup environment for Apple Silicon + if: matrix.os == 'macos-latest' && steps.check-arch.outputs.architecture == 'arm64' + run: | + bash ~/miniconda.sh -b -p $HOME/miniconda + source ~/miniconda/bin/activate + conda install -c apple tensorflow-deps + - name: Setup environment for AMD + if: matrix.os == 'macos-latest' && steps.check-arch.outputs.architecture == 'x86_64' + run: | + python3 -m venv ~/venv-metal + source ~/venv-metal/bin/activate + python -m pip install -U pip + - name: Setup environment for macOS + if: matrix.os == 'macos-latest' + run: | + python -m pip install tensorflow-macos + python -m pip install tensorflow-metal # Install pip and pytest - name: Install dependencies run: | @@ -40,6 +64,7 @@ jobs: - name: Execute test-all run: ./test-all shell: bash + # clippy: # runs-on: ubuntu-latest # strategy: From 8b815867423d3bec23af78e9dbdd6173c722d822 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Ma=C5=9Blanka?= Date: Fri, 7 Apr 2023 21:18:15 +0200 Subject: [PATCH 4/8] run plugable device test for macOS --- tensorflow-sys/tests/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-sys/tests/lib.rs b/tensorflow-sys/tests/lib.rs index a9bbebc523..d8ccea75da 100644 --- a/tensorflow-sys/tests/lib.rs +++ b/tensorflow-sys/tests/lib.rs @@ -48,7 +48,7 @@ fn tfe_tensor_handle() { } /// Test that the experimental API works. -#[cfg(feature = "experimental")] +#[cfg(all(feature = "experimental", target_os = "macos"))] #[test] fn load_plugable_device() { let c_filename = std::ffi::CString::new("libmetal_plugin.dylib").expect("CString::new failed"); From 392e1d38a972021e2c190dc654bc1a3ad9894743 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Ma=C5=9Blanka?= Date: Fri, 7 Apr 2023 21:42:54 +0200 Subject: [PATCH 5/8] add quotation mark --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 78a51dd3a7..2f5ccea2e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "tensorflow" version = "0.20.0" authors = [ - "Adam Crume , + "Adam Crume ", "Maciej Maślanka " ] description = "Rust language bindings for TensorFlow." From 9a61696c62761b8741bddd9fa1498e33192523f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Ma=C5=9Blanka?= Date: Fri, 7 Apr 2023 21:43:29 +0200 Subject: [PATCH 6/8] fix dependency name --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 04a4b21535..9382597a8d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,7 +54,7 @@ jobs: - name: Setup environment for macOS if: matrix.os == 'macos-latest' run: | - python -m pip install tensorflow-macos + python -m pip install tensorflow python -m pip install tensorflow-metal # Install pip and pytest - name: Install dependencies From 8e2567ca2fbc21e5869f058339d351f334078c1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Ma=C5=9Blanka?= Date: Fri, 7 Apr 2023 21:57:13 +0200 Subject: [PATCH 7/8] crate fix --- tensorflow-sys/src/experimental/mod.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorflow-sys/src/experimental/mod.rs b/tensorflow-sys/src/experimental/mod.rs index c036b73c19..da35649e0a 100644 --- a/tensorflow-sys/src/experimental/mod.rs +++ b/tensorflow-sys/src/experimental/mod.rs @@ -1,5 +1,3 @@ -use crate::{ - TF_Library, TF_Status, TF_DataType, -}; +use crate::{TF_DataType, TF_Library, TF_Status}; include!("c_api.rs"); From bccf2879f1f236f68da6f08895eeadf4b323454d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Ma=C5=9Blanka?= Date: Fri, 7 Apr 2023 22:21:29 +0200 Subject: [PATCH 8/8] system_version_compat=0 --- .github/workflows/ci.yml | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9382597a8d..1da2823b4c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,17 +45,14 @@ jobs: bash ~/miniconda.sh -b -p $HOME/miniconda source ~/miniconda/bin/activate conda install -c apple tensorflow-deps + SYSTEM_VERSION_COMPAT=0 pip install tensorflow-macos tensorflow-metal - name: Setup environment for AMD if: matrix.os == 'macos-latest' && steps.check-arch.outputs.architecture == 'x86_64' run: | python3 -m venv ~/venv-metal source ~/venv-metal/bin/activate python -m pip install -U pip - - name: Setup environment for macOS - if: matrix.os == 'macos-latest' - run: | - python -m pip install tensorflow - python -m pip install tensorflow-metal + SYSTEM_VERSION_COMPAT=0 pip install tensorflow-macos tensorflow-metal # Install pip and pytest - name: Install dependencies run: |