|
6 | 6 | import re
|
7 | 7 | import fnmatch
|
8 | 8 | import warnings
|
| 9 | +from copy import copy |
9 | 10 | from collections import OrderedDict
|
10 | 11 |
|
11 | 12 | import numpy as np
|
12 | 13 |
|
| 14 | +from larray.core.abstractbases import ABCArray |
13 | 15 | from larray.core.metadata import Metadata
|
14 | 16 | from larray.core.group import Group
|
15 |
| -from larray.core.axis import Axis |
| 17 | +from larray.core.axis import Axis, AxisCollection |
16 | 18 | from larray.core.constants import nan
|
17 | 19 | from larray.core.array import Array, get_axes, ndtest, zeros, zeros_like, sequence, asarray
|
18 | 20 | from larray.util.misc import float_error_handler_factory, is_interactive_interpreter, renamed_to, inverseop
|
@@ -95,7 +97,7 @@ def __init__(self, *args, **kwargs):
|
95 | 97 |
|
96 | 98 | if len(args) == 1:
|
97 | 99 | a0 = args[0]
|
98 |
| - if isinstance(a0, str): |
| 100 | + if isinstance(a0, basestring): |
99 | 101 | # assume a0 is a filename
|
100 | 102 | self.load(a0)
|
101 | 103 | else:
|
@@ -1475,6 +1477,132 @@ def display(k, v, is_metadata=False):
|
1475 | 1477 | return res
|
1476 | 1478 |
|
1477 | 1479 |
|
| 1480 | +class ArrayDef(ABCArray): |
| 1481 | + def __init__(self, axes): |
| 1482 | + if not all([isinstance(axis, (basestring, Axis)) for axis in axes]): |
| 1483 | + raise TypeError('ArrayDef only accepts string or Axis objects') |
| 1484 | + self.axes = axes |
| 1485 | + |
| 1486 | + |
| 1487 | +class ConstrainedSession(Session): |
| 1488 | + """ |
| 1489 | + Examples |
| 1490 | + -------- |
| 1491 | + Content of file 'model_variables.py' |
| 1492 | +
|
| 1493 | + >>> # ==== MODEL VARIABLES ==== |
| 1494 | + >>> class ModelVariables(ConstrainedSession): |
| 1495 | + ... FIRST_OBS_YEAR = int |
| 1496 | + ... FIRST_PROJ_YEAR = int |
| 1497 | + ... LAST_PROJ_YEAR = int |
| 1498 | + ... AGE = Axis |
| 1499 | + ... GENDER = Axis |
| 1500 | + ... TIME = Axis |
| 1501 | + ... G_CHILDREN = Group |
| 1502 | + ... G_ADULTS = Group |
| 1503 | + ... G_OBS_YEARS = Group |
| 1504 | + ... G_PROJ_YEARS = Group |
| 1505 | + ... population = ArrayDef(('AGE', 'GENDER', 'TIME')) |
| 1506 | + ... births = ArrayDef(('AGE', 'GENDER', 'TIME')) |
| 1507 | + ... deaths = ArrayDef(('AGE', 'GENDER', 'TIME')) |
| 1508 | +
|
| 1509 | + Content of file 'model.py' |
| 1510 | +
|
| 1511 | + >>> def run_model(variant_name, first_proj_year, last_proj_year): |
| 1512 | + ... # create an instance of the ModelVariables class |
| 1513 | + ... m = ModelVariables() |
| 1514 | + ... # ==== setup variables ==== |
| 1515 | + ... # set scalars |
| 1516 | + ... m.FIRST_OBS_YEAR = 1991 |
| 1517 | + ... m.FIRST_PROJ_YEAR = first_proj_year |
| 1518 | + ... m.LAST_PROJ_YEAR = last_proj_year |
| 1519 | + ... # set axes |
| 1520 | + ... m.AGE = Axis('age=0..120') |
| 1521 | + ... m.GENDER = Axis('gender=male,female') |
| 1522 | + ... m.TIME = Axis('time={}..{}'.format(m.FIRST_OBS_YEAR, m.LAST_PROJ_YEAR)) |
| 1523 | + ... # set groups |
| 1524 | + ... m.G_CHILDREN = m.AGE[:17] |
| 1525 | + ... m.G_ADULTS = m.AGE[18:] |
| 1526 | + ... m.G_OBS_YEARS = m.TIME[:m.FIRST_PROJ_YEAR-1] |
| 1527 | + ... m.G_PROJ_YEARS = m.TIME[m.FIRST_PROJ_YEAR:] |
| 1528 | + ... # set arrays |
| 1529 | + ... m.population = zeros((m.AGE, m.GENDER, m.TIME)) |
| 1530 | + ... m.births = zeros((m.AGE, m.GENDER, m.TIME)) |
| 1531 | + ... m.deaths = zeros((m.AGE, m.GENDER, m.TIME)) |
| 1532 | + ... # ==== model ==== |
| 1533 | + ... # some code here |
| 1534 | + ... # ... |
| 1535 | + ... # ==== output ==== |
| 1536 | + ... # save all variables in an HDF5 file |
| 1537 | + ... m.save('{variant_name}.h5', display=True) |
| 1538 | +
|
| 1539 | + Content of file 'main.py' |
| 1540 | +
|
| 1541 | + >>> run_model('proj_2020_2070', first_proj_year=2020, last_proj_year=2070) |
| 1542 | + dumping FIRST_OBS_YEAR ... done |
| 1543 | + dumping FIRST_PROJ_YEAR ... done |
| 1544 | + dumping LAST_PROJ_YEAR ... done |
| 1545 | + dumping AGE ... done |
| 1546 | + dumping GENDER ... done |
| 1547 | + dumping TIME ... done |
| 1548 | + dumping G_CHILDREN ... done |
| 1549 | + dumping G_ADULTS ... done |
| 1550 | + dumping G_OBS_YEARS ... done |
| 1551 | + dumping G_PROJ_YEARS ... done |
| 1552 | + dumping population ... done |
| 1553 | + dumping births ... done |
| 1554 | + dumping deaths ... done |
| 1555 | + """ |
| 1556 | + def __setitem__(self, key, value): |
| 1557 | + self._check_key_value(key, value) |
| 1558 | + |
| 1559 | + # we need to keep the attribute in sync (initially to mask the class attribute) |
| 1560 | + object.__setattr__(self, key, value) |
| 1561 | + self._objects[key] = value |
| 1562 | + |
| 1563 | + def __setattr__(self, key, value): |
| 1564 | + if key != 'meta': |
| 1565 | + self._check_key_value(key, value) |
| 1566 | + |
| 1567 | + # update the real attribute |
| 1568 | + object.__setattr__(self, key, value) |
| 1569 | + # update self._objects |
| 1570 | + Session.__setattr__(self, key, value) |
| 1571 | + |
| 1572 | + def _check_key_value(self, key, value): |
| 1573 | + cls = self.__class__ |
| 1574 | + attr_def = getattr(cls, key, None) |
| 1575 | + if attr_def is None: |
| 1576 | + warnings.warn("'{}' is not declared in '{}'".format(key, self.__class__.__name__), stacklevel=2) |
| 1577 | + else: |
| 1578 | + attr_type = Array if isinstance(attr_def, ArrayDef) else attr_def |
| 1579 | + if not isinstance(value, attr_type): |
| 1580 | + raise TypeError("Expected object of type '{}'. Got object of type '{}'." |
| 1581 | + .format(attr_type.__name__, value.__class__.__name__)) |
| 1582 | + if isinstance(attr_def, ArrayDef): |
| 1583 | + def get_axis(axis): |
| 1584 | + if isinstance(axis, basestring): |
| 1585 | + try: |
| 1586 | + axis = getattr(self, axis) |
| 1587 | + except AttributeError: |
| 1588 | + raise ValueError("Axis '{}' not defined in '{}'".format(axis, self.__class__.__name__)) |
| 1589 | + return axis |
| 1590 | + |
| 1591 | + defined_axes = AxisCollection([get_axis(axis) for axis in attr_def.axes]) |
| 1592 | + try: |
| 1593 | + defined_axes.check_compatible(value.axes) |
| 1594 | + except ValueError as error: |
| 1595 | + msg = str(error).replace("incompatible axes:", "incompatible axes for array '{}':".format(key))\ |
| 1596 | + .replace("vs", "was declared as") |
| 1597 | + raise ValueError(msg) |
| 1598 | + |
| 1599 | + def copy(self): |
| 1600 | + instance = self.__class__() |
| 1601 | + for key, value in self.items(): |
| 1602 | + instance[key] = copy(value) |
| 1603 | + return instance |
| 1604 | + |
| 1605 | + |
1478 | 1606 | def _exclude_private_vars(vars_dict):
|
1479 | 1607 | return {k: v for k, v in vars_dict.items() if not k.startswith('_')}
|
1480 | 1608 |
|
|
0 commit comments