Skip to content

Commit f65fae2

Browse files
committed
fix #832 : implemented ConstrainedSession class (based on preliminary code written by gdementen)
1 parent c115324 commit f65fae2

File tree

5 files changed

+586
-146
lines changed

5 files changed

+586
-146
lines changed

doc/source/api.rst

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,6 @@ Modifying
790790

791791
Session.add
792792
Session.update
793-
Session.get
794793
Session.apply
795794
Session.transpose
796795

@@ -816,6 +815,22 @@ Load/Save
816815
Session.to_hdf
817816
Session.to_pickle
818817

818+
ArrayDef
819+
========
820+
821+
.. autosummary::
822+
:toctree: _generated/
823+
824+
ArrayDef
825+
826+
ConstrainedSession
827+
==================
828+
829+
.. autosummary::
830+
:toctree: _generated/
831+
832+
ConstrainedSession
833+
819834
.. _api-editor:
820835

821836
Editor

larray/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
full_like, sequence, labels_array, ndtest, asarray, identity, diag,
1010
eye, all, any, sum, prod, cumsum, cumprod, min, max, mean, ptp, var,
1111
std, median, percentile, stack, zip_array_values, zip_array_items)
12-
from larray.core.session import Session, local_arrays, global_arrays, arrays
12+
from larray.core.session import Session, ConstrainedSession, ArrayDef, local_arrays, global_arrays, arrays
1313
from larray.core.constants import nan, inf, pi, e, euler_gamma
1414
from larray.core.metadata import Metadata
1515
from larray.core.ufuncs import wrap_elementwise_array_func, maximum, minimum, where
@@ -58,7 +58,7 @@
5858
'all', 'any', 'sum', 'prod', 'cumsum', 'cumprod', 'min', 'max', 'mean', 'ptp', 'var', 'std',
5959
'median', 'percentile', 'stack', 'zip_array_values', 'zip_array_items',
6060
# session
61-
'Session', 'local_arrays', 'global_arrays', 'arrays',
61+
'Session', 'ConstrainedSession', 'ArrayDef', 'local_arrays', 'global_arrays', 'arrays',
6262
# constants
6363
'nan', 'inf', 'pi', 'e', 'euler_gamma',
6464
# metadata

larray/core/session.py

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import re
77
import fnmatch
88
import warnings
9+
from copy import copy
910
from collections import OrderedDict
1011

1112
import numpy as np
1213

14+
from larray.core.abstractbases import ABCArray
1315
from larray.core.metadata import Metadata
1416
from larray.core.group import Group
15-
from larray.core.axis import Axis
17+
from larray.core.axis import Axis, AxisCollection
1618
from larray.core.constants import nan
1719
from larray.core.array import Array, get_axes, ndtest, zeros, zeros_like, sequence, asarray
1820
from larray.util.misc import float_error_handler_factory, is_interactive_interpreter, renamed_to, inverseop
@@ -95,7 +97,7 @@ def __init__(self, *args, **kwargs):
9597

9698
if len(args) == 1:
9799
a0 = args[0]
98-
if isinstance(a0, str):
100+
if isinstance(a0, basestring):
99101
# assume a0 is a filename
100102
self.load(a0)
101103
else:
@@ -1475,6 +1477,132 @@ def display(k, v, is_metadata=False):
14751477
return res
14761478

14771479

1480+
class ArrayDef(ABCArray):
1481+
def __init__(self, axes):
1482+
if not all([isinstance(axis, (basestring, Axis)) for axis in axes]):
1483+
raise TypeError('ArrayDef only accepts string or Axis objects')
1484+
self.axes = axes
1485+
1486+
1487+
class ConstrainedSession(Session):
1488+
"""
1489+
Examples
1490+
--------
1491+
Content of file 'model_variables.py'
1492+
1493+
>>> # ==== MODEL VARIABLES ====
1494+
>>> class ModelVariables(ConstrainedSession):
1495+
... FIRST_OBS_YEAR = int
1496+
... FIRST_PROJ_YEAR = int
1497+
... LAST_PROJ_YEAR = int
1498+
... AGE = Axis
1499+
... GENDER = Axis
1500+
... TIME = Axis
1501+
... G_CHILDREN = Group
1502+
... G_ADULTS = Group
1503+
... G_OBS_YEARS = Group
1504+
... G_PROJ_YEARS = Group
1505+
... population = ArrayDef(('AGE', 'GENDER', 'TIME'))
1506+
... births = ArrayDef(('AGE', 'GENDER', 'TIME'))
1507+
... deaths = ArrayDef(('AGE', 'GENDER', 'TIME'))
1508+
1509+
Content of file 'model.py'
1510+
1511+
>>> def run_model(variant_name, first_proj_year, last_proj_year):
1512+
... # create an instance of the ModelVariables class
1513+
... m = ModelVariables()
1514+
... # ==== setup variables ====
1515+
... # set scalars
1516+
... m.FIRST_OBS_YEAR = 1991
1517+
... m.FIRST_PROJ_YEAR = first_proj_year
1518+
... m.LAST_PROJ_YEAR = last_proj_year
1519+
... # set axes
1520+
... m.AGE = Axis('age=0..120')
1521+
... m.GENDER = Axis('gender=male,female')
1522+
... m.TIME = Axis('time={}..{}'.format(m.FIRST_OBS_YEAR, m.LAST_PROJ_YEAR))
1523+
... # set groups
1524+
... m.G_CHILDREN = m.AGE[:17]
1525+
... m.G_ADULTS = m.AGE[18:]
1526+
... m.G_OBS_YEARS = m.TIME[:m.FIRST_PROJ_YEAR-1]
1527+
... m.G_PROJ_YEARS = m.TIME[m.FIRST_PROJ_YEAR:]
1528+
... # set arrays
1529+
... m.population = zeros((m.AGE, m.GENDER, m.TIME))
1530+
... m.births = zeros((m.AGE, m.GENDER, m.TIME))
1531+
... m.deaths = zeros((m.AGE, m.GENDER, m.TIME))
1532+
... # ==== model ====
1533+
... # some code here
1534+
... # ...
1535+
... # ==== output ====
1536+
... # save all variables in an HDF5 file
1537+
... m.save('{variant_name}.h5', display=True)
1538+
1539+
Content of file 'main.py'
1540+
1541+
>>> run_model('proj_2020_2070', first_proj_year=2020, last_proj_year=2070)
1542+
dumping FIRST_OBS_YEAR ... done
1543+
dumping FIRST_PROJ_YEAR ... done
1544+
dumping LAST_PROJ_YEAR ... done
1545+
dumping AGE ... done
1546+
dumping GENDER ... done
1547+
dumping TIME ... done
1548+
dumping G_CHILDREN ... done
1549+
dumping G_ADULTS ... done
1550+
dumping G_OBS_YEARS ... done
1551+
dumping G_PROJ_YEARS ... done
1552+
dumping population ... done
1553+
dumping births ... done
1554+
dumping deaths ... done
1555+
"""
1556+
def __setitem__(self, key, value):
1557+
self._check_key_value(key, value)
1558+
1559+
# we need to keep the attribute in sync (initially to mask the class attribute)
1560+
object.__setattr__(self, key, value)
1561+
self._objects[key] = value
1562+
1563+
def __setattr__(self, key, value):
1564+
if key != 'meta':
1565+
self._check_key_value(key, value)
1566+
1567+
# update the real attribute
1568+
object.__setattr__(self, key, value)
1569+
# update self._objects
1570+
Session.__setattr__(self, key, value)
1571+
1572+
def _check_key_value(self, key, value):
1573+
cls = self.__class__
1574+
attr_def = getattr(cls, key, None)
1575+
if attr_def is None:
1576+
warnings.warn("'{}' is not declared in '{}'".format(key, self.__class__.__name__), stacklevel=2)
1577+
else:
1578+
attr_type = Array if isinstance(attr_def, ArrayDef) else attr_def
1579+
if not isinstance(value, attr_type):
1580+
raise TypeError("Expected object of type '{}'. Got object of type '{}'."
1581+
.format(attr_type.__name__, value.__class__.__name__))
1582+
if isinstance(attr_def, ArrayDef):
1583+
def get_axis(axis):
1584+
if isinstance(axis, basestring):
1585+
try:
1586+
axis = getattr(self, axis)
1587+
except AttributeError:
1588+
raise ValueError("Axis '{}' not defined in '{}'".format(axis, self.__class__.__name__))
1589+
return axis
1590+
1591+
defined_axes = AxisCollection([get_axis(axis) for axis in attr_def.axes])
1592+
try:
1593+
defined_axes.check_compatible(value.axes)
1594+
except ValueError as error:
1595+
msg = str(error).replace("incompatible axes:", "incompatible axes for array '{}':".format(key))\
1596+
.replace("vs", "was declared as")
1597+
raise ValueError(msg)
1598+
1599+
def copy(self):
1600+
instance = self.__class__()
1601+
for key, value in self.items():
1602+
instance[key] = copy(value)
1603+
return instance
1604+
1605+
14781606
def _exclude_private_vars(vars_dict):
14791607
return {k: v for k, v in vars_dict.items() if not k.startswith('_')}
14801608

larray/tests/data/test_session.h5

360 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)