Skip to content

Commit

Permalink
Make pywrap rules replicate final artifacts structure to ensure backw…
Browse files Browse the repository at this point in the history
…ard compatibility with users who directly use TensorFlow's shared object files.

PiperOrigin-RevId: 710637533
  • Loading branch information
vam-google authored and copybara-github committed Dec 30, 2024
1 parent 4b22bd9 commit 41dda56
Show file tree
Hide file tree
Showing 3 changed files with 385 additions and 242 deletions.
12 changes: 8 additions & 4 deletions third_party/py/rules_pywrap/pybind_extension.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,20 @@ def __update_globals(pywrap_m):

def __try_import():
imports_paths = [] # template_val
exceptions = []
last_exception = None
for import_path in imports_paths:
try:
pywrap_m = __import__(import_path, fromlist=["*"])
__update_globals(pywrap_m)
return
except ImportError:
# try another packge if there are any left
except ImportError as e:
exceptions.append(str(e))
last_exception = e
pass

raise RuntimeError(
"Could not detect original test/binary location, import paths tried: %s" % imports_paths)
raise RuntimeError(f"""
Could not import original test/binary location, import paths tried: {imports_paths}.
Previous exceptions: {exceptions}""", last_exception)

__try_import()
13 changes: 1 addition & 12 deletions third_party/py/rules_pywrap/pywrap.default.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ def pybind_extension(

# To patch top-level deps lists in sophisticated cases
pywrap_ignored_deps_filter = ["@pybind11", "@pybind11//:pybind11"],
pywrap_private_deps_filter = [
"@pybind11_abseil//pybind11_abseil:absl_casters",
"@pybind11_abseil//pybind11_abseil:import_status_module",
"@pybind11_abseil//pybind11_abseil:status_casters",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
],
pytype_srcs = None, # alias for data
hdrs = [], # merge into sources
pytype_deps = None, # ignore?
Expand All @@ -53,7 +47,6 @@ def pybind_extension(
pytype_deps,
]

private_deps_filter_dict = {k: None for k in pywrap_private_deps_filter}
ignored_deps_filter_dict = {k: None for k in pywrap_ignored_deps_filter}

actual_srcs = srcs + hdrs
Expand All @@ -67,13 +60,10 @@ def pybind_extension(
actual_private_deps = []
actual_default_deps = ["@pybind11//:pybind11"]

if type(deps) == list:
if not deps or type(deps) == list:
for dep in deps:
if dep in ignored_deps_filter_dict:
continue
if dep in private_deps_filter_dict:
actual_private_deps.append(dep)
continue
actual_deps.append(dep)
else:
actual_deps = deps
Expand All @@ -83,7 +73,6 @@ def pybind_extension(
name = name,
deps = actual_deps,
srcs = actual_srcs,
private_deps = actual_private_deps,
visibility = visibility,
win_def_file = win_def_file,
testonly = testonly,
Expand Down
Loading

0 comments on commit 41dda56

Please sign in to comment.