Skip to content

Parameterized domain types --- a mechanism for sub-workflow templates #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/test.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
-r base.in
numpy
pytest
4 changes: 3 additions & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SHA1:a035a60fcbac4cd7bf595dbd81ee7994505d4a95
# SHA1:3720d9b18830e4fcedb827ec36f6808035c1ea2c
#
# This file is autogenerated by pip-compile-multi
# To update, run:
Expand All @@ -10,6 +10,8 @@ exceptiongroup==1.1.1
# via pytest
iniconfig==2.0.0
# via pytest
numpy==1.24.4
# via -r test.in
pluggy==1.0.0
# via pytest
pytest==7.3.1
Expand Down
1 change: 1 addition & 0 deletions src/sciline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
__version__ = "0.0.0"

from .container import Container, UnsatisfiedRequirement, make_container
from .domain import parametrized_domain_type
2 changes: 2 additions & 0 deletions src/sciline/container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from __future__ import annotations

import typing
from functools import wraps
from typing import Callable, List, Type, TypeVar, Union
Expand Down
33 changes: 33 additions & 0 deletions src/sciline/domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import Dict, NewType


def parametrized_domain_type(name: str, base: type) -> type:
"""
Return a type-factory for parametrized domain types.

The types returned by the factory are created using typing.NewType. The returned
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The types return by the factory are created using typing.NewType. The returned
The types returned by the factory are created using typing.NewType. The returned

factory is used similarly to a Generic, but note that the factory itself should
not be used for annotations.

Parameters
----------
name:
The name of the type. This is used as a prefix for the names of the types
returned by the factory.
base:
The base type of the types returned by the factory.
"""

class Factory:
_subtypes: Dict[str, type] = {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So parametrized_domain_type must only be called once per name? That is, does using Monitor = parametrized_domain_type("Monitor", DataArray) in two different files lead to distinct types?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I can tell this is the same for typing.NewType?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. I just wanted to make sure. This looks like the right thing to do.


def __class_getitem__(cls, tp: type) -> type:
    """
    Return the domain type for ``tp``, creating and caching it on first use.

    The result is a ``typing.NewType`` of ``base`` whose name is
    ``f'{name}_{tp.__name__}'``. Later lookups with the same key return the
    cached object instead of creating a new one.
    """
    # NOTE: the key is built from ``tp.__name__`` alone, so two distinct types
    # that share a ``__name__`` resolve to the same cached domain type.
    key = f'{name}_{tp.__name__}'
    cached = cls._subtypes.get(key)
    if cached is not None:
        return cached
    created = NewType(key, base)
    cls._subtypes[key] = created
    return created

return Factory
105 changes: 105 additions & 0 deletions tests/complex_workflow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from dataclasses import dataclass
from typing import Callable, List, NewType

import dask
import numpy as np

import sciline as sl

# Run dask with its synchronous (single-threaded) scheduler.
# NOTE(review): the previous comment justified this with "call counting below",
# but nothing in this file counts calls; confirm whether the setting is still
# needed here or was carried over from container_test.py.
dask.config.set(scheduler='synchronous')


@dataclass
class RawData:
    """Raw input record for the test workflow: detector data plus two monitors."""

    # Detector values; multiplied element-wise by the detector mask downstream.
    data: np.ndarray
    # Reading wrapped as the incident monitor by the workflow providers.
    monitor1: float
    # Reading wrapped as the transmission monitor by the workflow providers.
    monitor2: float


# Run identifiers: distinct NewTypes over int, used as parameters of the
# parametrized domain types below (e.g. Raw[SampleRun]).
SampleRun = NewType('SampleRun', int)
BackgroundRun = NewType('BackgroundRun', int)
# Run-independent inputs shared by all instantiations of the sub-workflow.
DetectorMask = NewType('DetectorMask', np.ndarray)
DirectBeam = NewType('DirectBeam', np.ndarray)
SolidAngle = NewType('SolidAngle', np.ndarray)
# Parametrized domain types: indexing with a run type yields a per-run type
# based on the second argument.
Raw = sl.parametrized_domain_type('Raw', RawData)
Masked = sl.parametrized_domain_type('Masked', np.ndarray)
IncidentMonitor = sl.parametrized_domain_type('IncidentMonitor', float)
TransmissionMonitor = sl.parametrized_domain_type('TransmissionMonitor', float)
TransmissionFraction = sl.parametrized_domain_type('TransmissionFraction', float)
IofQ = sl.parametrized_domain_type('IofQ', np.ndarray)
# Final result combining the sample and background runs.
BackgroundSubtractedIofQ = NewType('BackgroundSubtractedIofQ', np.ndarray)


def reduction_factory(tp: type) -> List[Callable]:
def incident_monitor(x: Raw[tp]) -> IncidentMonitor[tp]:
Comment on lines +36 to +37
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy does not like this: tp is a variable (holding a type) and cannot be used for type-checking.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could try

def reduction_factory(tp: type) -> List[Callable]:
    T = TypeVar("T", bound=tp)

    def incident_monitor(x: Raw[T]) -> IncidentMonitor[T]:

But this probably doesn't work either.

return IncidentMonitor[tp](x.monitor1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if we could avoid repeating the return type. The injector should be able to convert the return value to the correct type. (Would likely need modifications to injector package.) But would this mess up type hinting?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least for run-time, you can just omit it. I don't know if the type-checker needs it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried it and neither mypy nor pyre like it when I just return the base type. So I guess we need to explicitly wrap the result.


def transmission_monitor(x: Raw[tp]) -> TransmissionMonitor[tp]:
    """Wrap the ``monitor2`` reading of the raw data as the transmission monitor."""
    reading = x.monitor2
    return TransmissionMonitor[tp](reading)

def mask_detector(x: Raw[tp], mask: DetectorMask) -> Masked[tp]:
    """Multiply the raw detector data by the mask and wrap it as ``Masked[tp]``."""
    masked = x.data * mask
    return Masked[tp](masked)

def transmission(
    incident: IncidentMonitor[tp], transmission: TransmissionMonitor[tp]
) -> TransmissionFraction[tp]:
    """
    Return ``incident / transmission`` wrapped as ``TransmissionFraction[tp]``.

    NOTE(review): the ratio is incident over transmission as written, and the
    expected values in the workflow test depend on that; confirm the
    orientation is intended.
    """
    ratio = incident / transmission
    return TransmissionFraction[tp](ratio)

def iofq(
    x: Masked[tp],
    solid_angle: SolidAngle,
    direct_beam: DirectBeam,
    transmission: TransmissionFraction[tp],
) -> IofQ[tp]:
    """
    Divide the masked data by the product of solid angle, direct beam and
    transmission fraction, and wrap the result as ``IofQ[tp]``.
    """
    normalization = solid_angle * direct_beam * transmission
    return IofQ[tp](x / normalization)

return [incident_monitor, transmission_monitor, mask_detector, transmission, iofq]


def raw_sample() -> Raw[SampleRun]:
    """Provide the sample run's raw data: four ones, monitors 1.0 and 2.0."""
    payload = RawData(data=np.ones(4), monitor1=1.0, monitor2=2.0)
    return Raw[SampleRun](payload)


def raw_background() -> Raw[BackgroundRun]:
    """Provide the background run's raw data: four 1.5s, monitors 1.0 and 4.0."""
    payload = RawData(data=np.ones(4) * 1.5, monitor1=1.0, monitor2=4.0)
    return Raw[BackgroundRun](payload)


def detector_mask() -> DetectorMask:
    """Provide the detector mask; the zero entry zeroes the third element."""
    flags = np.array([1, 1, 0, 1])
    return DetectorMask(flags)


def solid_angle() -> SolidAngle:
    """Provide the fixed four-element solid-angle array used by the workflow."""
    values = np.array([1.0, 0.5, 0.25, 0.125])
    return SolidAngle(values)


def direct_beam() -> DirectBeam:
    """Provide the direct beam as a zero-dimensional array holding ``1 / 1.5``."""
    value = np.array(1 / 1.5)
    return DirectBeam(value)


def subtract_background(
    sample: IofQ[SampleRun], background: IofQ[BackgroundRun]
) -> BackgroundSubtractedIofQ:
    """Return ``sample - background`` wrapped as ``BackgroundSubtractedIofQ``."""
    difference = sample - background
    return BackgroundSubtractedIofQ(difference)


def test_reduction_workflow():
    """
    Build a container from the shared providers plus one instantiation of the
    reduction sub-workflow per run type, then check all three results.
    """
    providers = [
        raw_sample,
        raw_background,
        detector_mask,
        solid_angle,
        direct_beam,
        subtract_background,
    ]
    providers += reduction_factory(SampleRun)
    providers += reduction_factory(BackgroundRun)
    container = sl.make_container(providers)

    # np.array_equal is an exact comparison (no tolerance).
    expectations = [
        (IofQ[SampleRun], [3, 6, 0, 24]),
        (IofQ[BackgroundRun], [9, 18, 0, 72]),
        (BackgroundSubtractedIofQ, [-6, -12, 0, -48]),
    ]
    for requested, expected in expectations:
        assert np.array_equal(container.get(requested), expected)
85 changes: 77 additions & 8 deletions tests/container_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import Callable, List, NewType

import dask
import pytest

Expand All @@ -9,26 +11,26 @@
dask.config.set(scheduler='synchronous')


def f(x: int) -> float:
def int_to_float(x: int) -> float:
return 0.5 * x


def g() -> int:
def make_int() -> int:
return 3


def h(x: int, y: float) -> str:
def int_float_to_str(x: int, y: float) -> str:
return f"{x};{y}"


def test_make_container_sets_up_working_container():
container = sl.make_container([f, g])
container = sl.make_container([int_to_float, make_int])
assert container.get(float) == 1.5
assert container.get(int) == 3


def test_make_container_does_not_autobind():
container = sl.make_container([f])
container = sl.make_container([int_to_float])
with pytest.raises(sl.UnsatisfiedRequirement):
container.get(float)

Expand All @@ -41,13 +43,15 @@ def provide_int() -> int:
ncall += 1
return 3

container = sl.make_container([f, provide_int, h], lazy=False)
container = sl.make_container(
[int_to_float, provide_int, int_float_to_str], lazy=False
)
assert container.get(str) == "3;1.5"
assert ncall == 1


def test_make_container_lazy_returns_task_that_computes_result():
container = sl.make_container([f, g], lazy=True)
container = sl.make_container([int_to_float, make_int], lazy=True)
task = container.get(float)
assert hasattr(task, 'compute')
assert task.compute() == 1.5
Expand All @@ -61,8 +65,73 @@ def provide_int() -> int:
ncall += 1
return 3

container = sl.make_container([f, provide_int, h], lazy=True)
container = sl.make_container(
[int_to_float, provide_int, int_float_to_str], lazy=True
)
task1 = container.get(float)
task2 = container.get(str)
assert dask.compute(task1, task2) == (1.5, '3;1.5')
assert ncall == 1


def test_make_container_with_subgraph_template():
    """
    Instantiate a one-provider template for two run types in a lazy container
    and check the combined result, and that the shared int provider ran once.
    """
    call_count = 0

    def provide_int() -> int:
        nonlocal call_count
        call_count += 1
        return 3

    Float = sl.parametrized_domain_type('Float', float)
    Str = sl.parametrized_domain_type('Str', str)

    def child(tp: type) -> List[Callable]:
        # Template: one provider whose input and output types depend on ``tp``.
        def int_float_to_str(x: int, y: Float[tp]) -> Str[tp]:
            combined = f"{x};{y}"
            return Str[tp](combined)

        return [int_float_to_str]

    Run1 = NewType('Run1', int)
    Run2 = NewType('Run2', int)
    Result = NewType('Result', str)

    def float1() -> Float[Run1]:
        value = 1.5
        return Float[Run1](value)

    def float2() -> Float[Run2]:
        value = 2.5
        return Float[Run2](value)

    def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result:
        joined = f"{s1};{s2}"
        return Result(joined)

    providers = [provide_int, float1, float2, use_strings]
    providers.extend(child(Run1))
    providers.extend(child(Run2))
    container = sl.make_container(providers, lazy=True)

    task = container.get(Result)
    assert task.compute() == "3;1.5;3;2.5"
    assert call_count == 1


# Module-level parametrized domain type used by `subworkflow` and
# `test_container_from_templated` below.
Str = sl.parametrized_domain_type('Str', str)


def subworkflow(tp: type) -> List[Callable]:
    """
    Return a single-provider workflow that turns a ``tp`` value into ``Str[tp]``.

    Parameters
    ----------
    tp:
        The input type of the provider; also the parameter of the output type.
    """
    # Str[tp] returns the same cached type on every lookup, so resolve it once.
    target = Str[tp]

    def f(x: tp) -> target:
        return target(f'{x}')

    return [f]


def test_container_from_templated():
    """
    Instantiate `subworkflow` for int and float and check that both templated
    results, and a provider combining them, can be obtained from the container.
    """

    def make_float() -> float:
        return 1.5

    def combine(x: Str[int], y: Str[float]) -> str:
        return f"{x};{y}"

    providers = [make_int, make_float, combine]
    providers += subworkflow(int)
    providers += subworkflow(float)
    container = sl.make_container(providers)

    assert container.get(Str[int]) == '3'
    assert container.get(Str[float]) == '1.5'
    assert container.get(str) == '3;1.5'