Refactor botorch/sampling/pathwise and add support for product kernels #2838
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #2838 +/- ##
===========================================
- Coverage 100.00% 99.70% -0.30%
===========================================
Files 211 214 +3
Lines 19320 19747 +427
===========================================
+ Hits 19320 19689 +369
- Misses 0 58 +58
Thanks @seashoo for the PR - this is a big one! It'll take me a bit of time to review this in detail, I plan to do a first higher-level pass this week.
What exactly does this mean?
Hi @Balandat, Thanks for the response! We've included a more detailed Project Overview section in the pull request description to clarify our validation approach. Specifically, we utilized the existing unit test files, which cover prior, updates, and posterior sampling, and ensured that all tests passed as part of this rebase. While these tests are comprehensive, we welcome any additional guidance you might have on further validating the code's robustness.
I went over the main code in the PR in detail; overall this looks great, thanks for the effort and patching some gaps (e.g. _gaussian_update_ModelListGP). I have not reviewed the testing code in detail, but can do that after the next pass.
The key things to address are:
- Some additions in the patch file were not included here - curious to understand why (and if this was an oversight let's add them in - I pointed out which ones).
- Currently the tests still have some coverage gaps based on the codecov report here. Please add some test cases to also cover the currently uncovered lines.
r"""
.. [rahimi2007random]
    A. Rahimi and B. Recht. Random features for large-scale kernel machines.
    Advances in Neural Information Processing Systems 20 (2007).

.. [sutherland2015error]
    D. J. Sutherland and J. Schneider. On the error of random Fourier features.
    arXiv preprint arXiv:1506.02785 (2015).
"""
Why remove these references?
# generators.
# It defines a callable that takes a kernel and dimension parameters and returns a
# KernelFeatureMap.
TKernelFeatureMapGenerator = Callable[[kernels.Kernel, int, int], KernelFeatureMap]
thanks for adding these comments, they're very helpful
num_inputs: int,
num_outputs: int,
num_random_features: int = 1024,
num_ambient_inputs: Optional[int] = None,
We generally use the PEP 604 style for type annotations - let's update this throughout the code here.
- num_ambient_inputs: Optional[int] = None,
+ num_ambient_inputs: int | None = None,
# IMPLEMENTATION NOTE: This function serves as the main entry point for generating
# feature maps from kernels. It uses the dispatcher to call the appropriate handler
# based on the kernel type. The function has been updated from the original
# implementation
# to use more descriptive parameter names (num_ambient_inputs instead of num_inputs,
# and num_random_features instead of num_outputs) to better reflect their purpose.
Could you please keep the docstring format the same as in the previous code (this is also contained in product_kernel_diff.txt)? This applies throughout the code.
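As background on the dispatch pattern the implementation note describes (an entry point that picks a handler from the kernel's type), here is a minimal sketch. It uses the standard library's `functools.singledispatch` as a stand-in and does not use BoTorch's own dispatcher; the classes `Kernel`, `RBF`, `Matern` and the function `describe_feature_map` are placeholders invented for illustration.

```python
from functools import singledispatch


class Kernel:
    """Placeholder base class standing in for a kernel type."""


class RBF(Kernel):
    """Placeholder kernel type."""


class Matern(Kernel):
    """Placeholder kernel type."""


@singledispatch
def describe_feature_map(kernel, num_random_features: int = 1024) -> str:
    # Fallback used when no handler is registered for the kernel's type.
    raise NotImplementedError(f"No handler registered for {type(kernel).__name__}.")


@describe_feature_map.register
def _(kernel: RBF, num_random_features: int = 1024) -> str:
    # Handler selected for RBF instances (and subclasses of RBF).
    return f"RBF handler with {num_random_features} features"


@describe_feature_map.register
def _(kernel: Matern, num_random_features: int = 1024) -> str:
    # Handler selected for Matern instances (and subclasses of Matern).
    return f"Matern handler with {num_random_features} features"
```

Calling `describe_feature_map(RBF(), 8)` routes to the RBF handler, while an unregistered type such as the bare `Kernel` falls through to the error.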
    device=kernel.device,
    dtype=kernel.dtype,
)
output_transforms = [transforms.SineCosineTransform(constant)]
This doesn't seem to be the same as the behavior in the patch file? Relevant code snippet of what I think this should look like:
output_transforms = [transforms.ConstantMulTransform(constant)]
if cosine_only:
    bias = 2 * pi * torch.rand(num_random_features, **tkwargs)
    num_raw_features = num_random_features
    output_transforms.append(transforms.CosineTransform())
elif num_random_features % 2:
    raise UnsupportedError(
        f"Expected an even number of random features, but {num_random_features=}."
    )
else:
    bias = None
    num_raw_features = num_random_features // 2
    output_transforms.append(transforms.SineCosineTransform())
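To make the two branches concrete, here is a minimal pure-Python sketch of the textbook random Fourier feature construction for an RBF kernel. It mirrors only the branching above (cosine-only with a random phase versus paired sine and cosine features, which need an even feature count); it does not use BoTorch's transforms or claim to reproduce their ordering. The function `make_rff` and its parameters are assumptions made for illustration, and `ValueError` stands in for `UnsupportedError`.

```python
import math
import random


def make_rff(num_inputs, num_random_features, lengthscale=1.0, cosine_only=False, seed=0):
    """Return phi such that phi(x) . phi(y) approximates exp(-|x - y|^2 / (2 l^2))."""
    rng = random.Random(seed)
    if cosine_only:
        # One raw feature per output, each shifted by a uniform phase in [0, 2*pi).
        num_raw_features = num_random_features
        bias = [2 * math.pi * rng.random() for _ in range(num_raw_features)]
    elif num_random_features % 2:
        raise ValueError(
            f"Expected an even number of random features, but {num_random_features=}."
        )
    else:
        # Each raw feature yields one sine and one cosine output, so halve the count.
        num_raw_features = num_random_features // 2
        bias = None
    # Spectral samples of the RBF kernel: w ~ N(0, 1 / lengthscale^2) per dimension.
    weight = [
        [rng.gauss(0.0, 1.0 / lengthscale) for _ in range(num_inputs)]
        for _ in range(num_raw_features)
    ]
    constant = math.sqrt(2.0 / num_random_features)

    def phi(x):
        raw = [sum(w_j * x_j for w_j, x_j in zip(w, x)) for w in weight]
        if cosine_only:
            return [constant * math.cos(r + b) for r, b in zip(raw, bias)]
        return [constant * math.sin(r) for r in raw] + [constant * math.cos(r) for r in raw]

    return phi


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))
```

In the sine-cosine branch, `dot(phi(x), phi(x))` equals 1 up to rounding because each sine and cosine pair contributes `sin^2 + cos^2`; the cosine-only branch matches the kernel only in expectation.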
noise_values = torch.randn_like(sample_values).unsqueeze(-1)
noise_values = noise_covariance.cholesky() @ noise_values
sample_values = sample_values + noise_values.squeeze(-1)
# Generate noise values with correct shape
thanks for fixing this
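For context, the snippet above relies on the identity that if `z ~ N(0, I)` and `L @ L^T = cov`, then `L @ z ~ N(0, cov)`. A minimal pure-Python sketch of that identity for a 2x2 covariance follows; `cholesky_2x2` and `sample_noise` are hypothetical helpers written for illustration, not the PR's code and not a PyTorch API.

```python
import math
import random


def cholesky_2x2(cov):
    """Lower-triangular L with L @ L^T == cov for a symmetric positive-definite 2x2 matrix."""
    (a, b), (_, d) = cov
    l11 = math.sqrt(a)
    l21 = b / l11
    l22 = math.sqrt(d - l21 * l21)
    return [[l11, 0.0], [l21, l22]]


def sample_noise(cov, rng):
    """Draw one sample from N(0, cov) as L @ z with z ~ N(0, I)."""
    L = cholesky_2x2(cov)
    z = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
    return [L[0][0] * z[0], L[1][0] * z[0] + L[1][1] * z[1]]
```

The shape handling in the PR (adding and then dropping a trailing dimension) exists so that the matrix product acts on column vectors; the sketch sidesteps this by working with plain lists.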
task_index = (
    num_inputs + model._task_feature
    if model._task_feature < 0
    else model._task_feature
)
It would be better to do this so we can always assume that it's positive and we don't have to do this custom handling. But I think this is ok for now as is, that change is beyond the scope of this PR.
- task_index = (
-     num_inputs + model._task_feature
-     if model._task_feature < 0
-     else model._task_feature
- )
+ # TODO: Changed `MultiTaskGP` to normalize the task feature in its constructor.
+ task_index = (
+     num_inputs + model._task_feature
+     if model._task_feature < 0
+     else model._task_feature
+ )

@GaussianUpdate.register(ModelListGP, LikelihoodList)
def _gaussian_update_ModelListGP(
This one is completely new, right? Very nice!
class SineCosineTransform(TensorTransform):
    r"""A transform that returns concatenated sine and cosine features."""

    def __init__(self, scale: Optional[Tensor] = None):
- def __init__(self, scale: Optional[Tensor] = None):
+ def __init__(self, scale: Tensor | None = None):
# Removed unused imports
# from botorch.sampling.pathwise.utils.transforms import (
#     ChainedTransform,
#     FeatureSelector
# )
if unused let's delete them outright - applies throughout the code
- # Removed unused imports
- # from botorch.sampling.pathwise.utils.transforms import (
- #     ChainedTransform,
- #     FeatureSelector
- # )
Motivation
Hi! I'm Sahran Ashoor, an undergraduate research assistant in the Uncertainty Quantification Lab at the University of Houston. I work under Dr. Ruda Zhang and Taiwo Adebiyi, both of whom have already spoken with Max Balandat about incorporating a rebase of botorch/sampling/pathwise (largely written by James T. Wilson). The changes in this pull request are my best attempt at faithfully implementing the change log I was provided (product_kernel_diff.txt).
Have you read the Contributing Guidelines on pull requests?
Yes!
Project Overview
The primary goal was to make Wilson's original codebase compatible with the latest BoTorch version. To achieve this, we started from the original source code and test suites, which initially revealed several incompatibilities. Our main contribution was carefully rebasing Wilson's code while preserving the logic for pathwise sampling. Importantly, all changes were confined to the botorch/sampling/pathwise directory to ensure seamless integration, and the code passes both the local pathwise test suites and BoTorch's global test suites via GitHub workflows.
In terms of code logic, we relied on Wilson's unit tests for prior, updates, and posterior sampling, which we believe are sufficient to validate the correctness of the implementation. However, we welcome your feedback on this approach, and would appreciate any suggestions for additional tests or example scripts to further confirm the robustness of the changes. We are open to collaborating further on this effort.
Test Plan
The entire test suite was run with pytest. Additional verification suggests that the logic may be offset, but we hope to work with you to validate these changes further, at the discretion of Dr. Zhang. Expect further communication directly from my lab with more insight into the rebase.
Related PRs
N/A