Skip to content

Commit 8da6f37

Browse files
authored
Merge pull request #34822 from lgeiger/r2.1-cherry-pick-pip-package-gen
[r2.1 cherry-pick] Fix pip package API generation
2 parents 03b0dcb + 33340d1 commit 8da6f37

6 files changed

+25
-29
lines changed

tensorflow/api_template.__init__.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ def _running_from_pip_package():
119119
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
120120

121121
if _running_from_pip_package():
122-
for s in _site_packages_dirs:
122+
for _s in _site_packages_dirs:
123123
# TODO(gunan): Add sanity checks to loaded modules here.
124-
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
125-
if _fi.file_exists(plugin_dir):
126-
_ll.load_library(plugin_dir)
124+
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
125+
if _fi.file_exists(_plugin_dir):
126+
_ll.load_library(_plugin_dir)
127127

128128
# Add module aliases
129129
if hasattr(_current_module, 'keras'):
@@ -136,3 +136,5 @@ def _running_from_pip_package():
136136
setattr(_current_module, "optimizers", optimizers)
137137
setattr(_current_module, "initializers", initializers)
138138
# pylint: enable=undefined-variable
139+
140+
# __all__ PLACEHOLDER

tensorflow/api_template_v1.__init__.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,10 @@ def _running_from_pip_package():
132132
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
133133

134134
if _running_from_pip_package():
135-
for s in _site_packages_dirs:
135+
for _s in _site_packages_dirs:
136136
# TODO(gunan): Add sanity checks to loaded modules here.
137-
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
138-
if _fi.file_exists(plugin_dir):
139-
_ll.load_library(plugin_dir)
137+
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
138+
if _fi.file_exists(_plugin_dir):
139+
_ll.load_library(_plugin_dir)
140140

141+
# __all__ PLACEHOLDER

tensorflow/python/tools/api/generator/create_python_api.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,12 @@ def build(self):
243243
# from it using * import. Don't need this for lazy_loading because the
244244
# underscore symbols are already included in __all__ when passed in and
245245
# handled by TFModuleWrapper.
246+
root_module_footer = ''
246247
if not self._lazy_loading:
247248
underscore_names_str = ', '.join(
248249
'\'%s\'' % name for name in self._underscore_names_in_root)
249250

250-
module_text_map[''] = module_text_map.get('', '') + '''
251+
root_module_footer = '''
251252
_names_with_underscore = [%s]
252253
__all__ = [_s for _s in dir() if not _s.startswith('_')]
253254
__all__.extend([_s for _s in _names_with_underscore])
@@ -273,7 +274,7 @@ def build(self):
273274
footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
274275
dest_module, public_apis_name, deprecation, has_lite)
275276

276-
return module_text_map, footer_text_map
277+
return module_text_map, footer_text_map, root_module_footer
277278

278279
def format_import(self, source_module_name, source_name, dest_name):
279280
"""Formats import statement.
@@ -620,7 +621,11 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
620621
os.makedirs(os.path.dirname(file_path))
621622
open(file_path, 'a').close()
622623

623-
module_text_map, deprecation_footer_map = get_api_init_text(
624+
(
625+
module_text_map,
626+
deprecation_footer_map,
627+
root_module_footer,
628+
) = get_api_init_text(
624629
packages, output_package, api_name,
625630
api_version, compat_api_versions, lazy_loading, use_relative_imports)
626631

@@ -652,6 +657,7 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
652657
with open(root_init_template, 'r') as root_init_template_file:
653658
contents = root_init_template_file.read()
654659
contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
660+
contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer)
655661
elif module in compat_module_to_template:
656662
# Read base init file for compat module
657663
with open(compat_module_to_template[module], 'r') as init_template_file:

tensorflow/python/tools/api/generator/create_python_api_test.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def tearDown(self):
6262
del sys.modules[_MODULE_NAME]
6363

6464
def testFunctionImportIsAdded(self):
65-
imports, _ = create_python_api.get_api_init_text(
65+
imports, _, _ = create_python_api.get_api_init_text(
6666
packages=[create_python_api._DEFAULT_PACKAGE],
6767
output_package='tensorflow',
6868
api_name='tensorflow',
@@ -97,7 +97,7 @@ def testFunctionImportIsAdded(self):
9797
msg='compat.v1 in %s' % str(imports.keys()))
9898

9999
def testClassImportIsAdded(self):
100-
imports, _ = create_python_api.get_api_init_text(
100+
imports, _, _ = create_python_api.get_api_init_text(
101101
packages=[create_python_api._DEFAULT_PACKAGE],
102102
output_package='tensorflow',
103103
api_name='tensorflow',
@@ -116,7 +116,7 @@ def testClassImportIsAdded(self):
116116
msg='%s not in %s' % (expected_import, str(imports)))
117117

118118
def testConstantIsAdded(self):
119-
imports, _ = create_python_api.get_api_init_text(
119+
imports, _, _ = create_python_api.get_api_init_text(
120120
packages=[create_python_api._DEFAULT_PACKAGE],
121121
output_package='tensorflow',
122122
api_name='tensorflow',
@@ -132,7 +132,7 @@ def testConstantIsAdded(self):
132132
msg='%s not in %s' % (expected, str(imports)))
133133

134134
def testCompatModuleIsAdded(self):
135-
imports, _ = create_python_api.get_api_init_text(
135+
imports, _, _ = create_python_api.get_api_init_text(
136136
packages=[create_python_api._DEFAULT_PACKAGE],
137137
output_package='tensorflow',
138138
api_name='tensorflow',
@@ -144,7 +144,7 @@ def testCompatModuleIsAdded(self):
144144
msg='compat.v1.test not in %s' % str(imports.keys()))
145145

146146
def testNestedCompatModulesAreAdded(self):
147-
imports, _ = create_python_api.get_api_init_text(
147+
imports, _, _ = create_python_api.get_api_init_text(
148148
packages=[create_python_api._DEFAULT_PACKAGE],
149149
output_package='tensorflow',
150150
api_name='tensorflow',

tensorflow/virtual_root_template_v1.__init__.py

-3
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,4 @@ def _forward_module(old_name):
132132
except NameError:
133133
pass
134134

135-
# Manually patch keras and estimator so tf.keras and tf.estimator work
136-
keras = _sys.modules["tensorflow.keras"]
137-
if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"]
138135
# LINT.ThenChange(//tensorflow/virtual_root_template_v2.__init__.py.oss)

tensorflow/virtual_root_template_v2.__init__.py

-10
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,4 @@ def _forward_module(old_name):
126126
except NameError:
127127
pass
128128

129-
# TODO(mihaimaruseac): Revisit all of this once we release 2.1
130-
# Manually patch keras and estimator so tf.keras and tf.estimator work
131-
keras = _sys.modules["tensorflow.keras"]
132-
if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"]
133-
# Also import module aliases
134-
try:
135-
from tensorflow_core import losses, metrics, initializers, optimizers
136-
except ImportError:
137-
pass
138-
139129
# LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)

0 commit comments

Comments
 (0)