19 changes: 13 additions & 6 deletions mesmerize_core/caiman_extensions/_utils.py
@@ -1,7 +1,9 @@
from functools import wraps
from typing import Union
from typing import Optional, Union, Protocol
from uuid import UUID

import pandas as pd

from mesmerize_core.utils import wrapsmethod
from mesmerize_core.caiman_extensions._batch_exceptions import (
BatchItemNotRunError,
BatchItemUnsuccessfulError,
@@ -10,9 +12,14 @@
)


def validate(algo: str = None):
class SeriesExtensions(Protocol):
"""Common interface for series accessors to help with type hinting"""
_series: pd.Series


def validate(algo: Optional[str] = None):
def dec(func):
@wraps(func)
@wrapsmethod(func)
def wrapper(self, *args, **kwargs):
if self._series["outputs"] is None:
raise BatchItemNotRunError("Item has not been run")
@@ -38,7 +45,7 @@ def wrapper(self, *args, **kwargs):
def _verify_and_lock_batch_file(func):
"""Acquires lock and ensures batch file has the same items as current df before calling wrapped function"""

@wraps(func)
@wrapsmethod(func)
def wrapper(instance, *args, **kwargs):
with instance._batch_lock:
disk_df = instance.reload_from_disk()
@@ -53,7 +60,7 @@ def wrapper(instance, *args, **kwargs):


def _index_parser(func):
@wraps(func)
@wrapsmethod(func)
def _parser(instance, *args, **kwargs):
if "index" in kwargs.keys():
index: Union[int, str, UUID] = kwargs["index"]
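A note on the new `SeriesExtensions` protocol: it lets the decorators in this module annotate `self` without importing the concrete accessor classes, relying on structural typing. A minimal sketch of how a class satisfies it (the `DemoAccessor` class and `get_uuid` helper are hypothetical, for illustration only):

from typing import Protocol

import pandas as pd

class SeriesExtensions(Protocol):
    """Anything with a `_series: pd.Series` attribute matches structurally."""
    _series: pd.Series

class DemoAccessor:  # hypothetical; no inheritance from the protocol is needed
    def __init__(self, series: pd.Series):
        self._series = series

def get_uuid(ext: SeriesExtensions) -> str:
    # a type checker accepts DemoAccessor here because it has `_series`
    return str(ext._series["uuid"])

print(get_uuid(DemoAccessor(pd.Series({"uuid": "abc-123"}))))  # abc-123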
235 changes: 130 additions & 105 deletions mesmerize_core/caiman_extensions/cache.py
@@ -1,17 +1,22 @@
from functools import wraps
from typing import Union, Optional
import inspect
from typing import Union, Optional, TypeVar

import pandas as pd
import time
import numpy as np
import sys
from caiman.source_extraction.cnmf import CNMF
import re
from sys import getsizeof
import copy

from ..utils import wrapsmethod
from ._utils import SeriesExtensions

def _check_arg_equality(args, cache_args):

# return type of decorated method
R = TypeVar("R")


def _check_arg_equality(args, cache_args) -> bool:
if not type(args) == type(cache_args):
return False
if isinstance(cache_args, np.ndarray):
@@ -20,42 +25,65 @@ def _check_arg_equality(args, cache_args):
return cache_args == args


def _check_args_equality(args, cache_args):
def _check_args_equality(args, cache_args) -> bool:
if len(args) != len(cache_args):
return False
equality = list()

if isinstance(args, tuple):
for arg, cache_arg in zip(args, cache_args):
equality.append(_check_arg_equality(arg, cache_arg))
if not _check_arg_equality(arg, cache_arg):
return False
else:
for k in args.keys():
equality.append(_check_arg_equality(args[k], cache_args[k]))
return all(equality)
for k, v in args.items():
if k not in cache_args or not _check_arg_equality(v, cache_args[k]):
return False
return True
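For reviewers: the np.ndarray special case above exists because `==` on arrays is elementwise, so its result cannot be used directly as a truth value; the collapsed branch presumably reduces it (e.g. with `.all()`). A short illustration of the pitfall:

import numpy as np

a, b = np.array([1, 2, 3]), np.array([1, 2, 3])
# `if a == b:` raises ValueError: the truth value of an array is ambiguous
assert (a == b).all()                        # reduce the elementwise result
assert not (a == np.array([1, 2, 9])).all()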


def _return_wrapper(output, copy_bool):
def _return_wrapper(output: R, copy_bool: bool) -> R:
if copy_bool == True:
return copy.deepcopy(output)
else:
return output
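`_return_wrapper` deep-copies by default so callers cannot mutate the cached object in place (the `R` TypeVar just preserves the wrapped return type for checkers). The aliasing hazard it guards against, sketched:

import copy

import numpy as np

cached = np.zeros(3)
alias = cached                # returning the stored object directly aliases it
alias[0] = 99                 # an in-place edit by a caller corrupts the cache
safe = copy.deepcopy(cached)  # what _return_wrapper does when copy_bool is True
safe[1] = 42
assert cached[1] == 0         # edits to the copy leave the cached array intact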


def _get_item_size(item) -> int:
"""Recursively compute size of return value"""
if isinstance(item, np.ndarray):
return item.data.nbytes

elif isinstance(item, (tuple, list)):
size = 0
for entry in item:
size += _get_item_size(entry)
return size

elif isinstance(item, CNMF):
size = 0
for attr in item.estimates.__dict__.values():
size += _get_item_size(attr)
return size

else:
return sys.getsizeof(item)
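A quick sanity check of the recursive sizing, assuming `_get_item_size` above is in scope (exact `sys.getsizeof` values vary by platform, so the string's contribution is approximate):

import numpy as np

item = [np.zeros((10, 10)), (np.ones(5, dtype=np.float32), "label")]
# 100 float64 values (800 bytes) + 5 float32 values (20 bytes) + getsizeof("label")
print(_get_item_size(item))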


class Cache:
def __init__(self, cache_size: Optional[Union[int, str]] = None):
self.cache = pd.DataFrame(
data=None,
columns=["uuid", "function", "args", "kwargs", "return_val", "time_stamp"],
columns=["uuid", "function", "kwargs", "return_val", "time_stamp", "added_time", "bytes"],
)
self.set_maxsize(cache_size)

def get_cache(self):
return self.cache

def clear_cache(self):
while len(self.cache.index) != 0:
while len(self.cache) != 0:
self.cache.drop(index=self.cache.index[-1], axis=0, inplace=True)

def set_maxsize(self, max_size: Union[int, str]):
def set_maxsize(self, max_size: Optional[Union[int, str]]):
if max_size is None:
self.storage_type = "RAM"
self.size = 1024**3
@@ -70,122 +98,119 @@ def set_maxsize(self, max_size: Union[int, str]):
self.size = max_size

def _get_cache_size_bytes(self):
"""Returns in bytes"""
cache_size = 0
for i in range(len(self.cache.index)):
if isinstance(self.cache.iloc[i, 4], np.ndarray):
cache_size += self.cache.iloc[i, 4].data.nbytes
elif isinstance(self.cache.iloc[i, 4], (tuple, list)):
for lists in self.cache.iloc[i, 4]:
for array in lists:
cache_size += array.data.nbytes
elif isinstance(self.cache.iloc[i, 4], CNMF):
sizes = list()
for attr in self.cache.iloc[i, 4].estimates.__dict__.values():
if isinstance(attr, np.ndarray):
sizes.append(attr.data.nbytes)
else:
sizes.append(getsizeof(attr))
else:
cache_size += sys.getsizeof(self.cache.iloc[i, 4])

return cache_size
return self.cache.loc[:, "bytes"].sum()

def use_cache(self, func):
@wraps(func)
def _use_cache(instance, *args, **kwargs):
if "return_copy" in kwargs.keys():
return_copy = kwargs["return_copy"]
else:
return_copy = True

if self.size == 0:
"""
Caching decorator.

Usage:

.. code-block:: python

@cache.use_cache
def my_costly_method(self, *, return_copy=True):
...

return_copy determines whether an entry found in the cache is deep-copied before being returned.
The decorated function *must* take return_copy as a keyword-only parameter, and this will be read by the decorator.
"""
# get default value of return_copy from function signature
params = inspect.signature(func).parameters
return_copy_arg = params.get("return_copy")
if return_copy_arg is None:
raise TypeError("return_copy must be in wrapped function signature when not provided to decorator")
elif return_copy_arg.kind != inspect.Parameter.KEYWORD_ONLY:
raise TypeError("return_copy must be a keyword-only argument")
elif return_copy_arg.default == inspect.Parameter.empty:
return_copy_default = None  # unlikely: with no default in the signature, callers must always pass return_copy
else:
return_copy_default = return_copy_arg.default
assert isinstance(return_copy_default, bool), "return_copy default should be bool"

@wrapsmethod(func)
def _use_cache_wrapper(instance: SeriesExtensions, *args, **kwargs):
# extract return_copy; return_copy is keyword only, so only have to look in kwargs
return_copy = kwargs.get("return_copy", return_copy_default)
if return_copy is None: # no default case
raise TypeError("Must provide a value for return_copy")

if not isinstance(return_copy, bool):
raise TypeError("return_copy must be a bool")

# if we are not storing anything in the cache, just do the function call, no need to search
# still make a copy if return_copy is set, to make absolutely sure the result isn't aliasing another object
if self.size == 0:
self.clear_cache()
return _return_wrapper(func(instance, *args, **kwargs), return_copy)

# if cache is empty, will always be a cache miss
if len(self.cache.index) == 0:
return_val = func(instance, *args, **kwargs)
self.cache.loc[len(self.cache.index)] = [
instance._series["uuid"],
func.__name__,
args,
kwargs,
return_val,
time.time(),
]
return _return_wrapper(return_val, copy_bool=return_copy)

return _return_wrapper(func(instance, *args, **kwargs), copy_bool=return_copy)

# iterate through signature and make dict containing arguments to compare, including defaults
args_dict = {}
for i, (param_name, param) in enumerate(params.items()):
if i == 0 or param_name == "return_copy":
continue # skip self/instance and return_copy
elif i-1 < len(args):
args_dict[param_name] = args[i-1]
elif param_name in kwargs:
args_dict[param_name] = kwargs[param_name]
else:
assert param.default != inspect.Parameter.empty, "must have a default argument or there would be a TypeError"
args_dict[param_name] = param.default
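The loop above folds positional arguments, explicit keyword arguments, and signature defaults into one dict keyed by parameter name, so equivalent call spellings compare equal. The same normalization can be sketched with `inspect.Signature.bind` (illustrative only, not the PR's code; it assumes the first parameter is named `self` and the signature has no `*args`/`**kwargs`):

import inspect

def normalize_call_args(func, instance, args, kwargs) -> dict:
    bound = inspect.signature(func).bind(instance, *args, **kwargs)
    bound.apply_defaults()               # fill in any unspecified defaults
    normalized = dict(bound.arguments)
    normalized.pop("self", None)         # the instance is tracked via its uuid
    normalized.pop("return_copy", None)  # never affects the cached value
    return normalized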
# checking to see if there is a cache hit
for i in range(len(self.cache.index)):
for ind, row in self.cache.iterrows():
if (
self.cache.iloc[i, 0] == instance._series["uuid"]
and self.cache.iloc[i, 1] == func.__name__
and _check_args_equality(args, self.cache.iloc[i, 2])
and _check_args_equality(kwargs, self.cache.iloc[i, 3])
row.at["uuid"] == instance._series["uuid"]
and row.at["function"] == func.__name__
and _check_args_equality(args_dict, row.at["kwargs"])
):
self.cache.iloc[i, 5] = time.time()
return_val = self.cache.iloc[i, 4]
return _return_wrapper(self.cache.iloc[i, 4], copy_bool=return_copy)
self.cache.at[ind, "time_stamp"] = time.time()  # write through .at: rows yielded by iterrows are copies
return _return_wrapper(row.at["return_val"], copy_bool=return_copy)

# no cache hit: check the cache limit, and if it would be exceeded, evict the least recently used entry before adding the new one
# if memory type is 'ITEMS': drop the least recently used and then add new item
if self.storage_type == "ITEMS" and len(self.cache.index) >= self.size:
return_val = func(instance, *args, **kwargs)
return_val = func(instance, *args, **kwargs)
curr_val_size = _get_item_size(return_val)
if self.storage_type == "RAM" and curr_val_size > self.size:
# too big to fit in the cache, and no point in evicting other items, so just return
return _return_wrapper(return_val, copy_bool=return_copy)

if self.storage_type == "ITEMS" and len(self.cache) >= self.size:
self.cache.drop(
index=self.cache.sort_values(
by=["time_stamp"], ascending=False
).index[-1],
axis=0,
inplace=True,
)
self.cache = self.cache.reset_index(drop=True)
self.cache.loc[len(self.cache.index)] = [
instance._series["uuid"],
func.__name__,
args,
kwargs,
return_val,
time.time(),
]
return _return_wrapper(
self.cache.iloc[len(self.cache.index) - 1, 4], copy_bool=return_copy
)
self.cache.reset_index(drop=True, inplace=True)

# if storage type is 'RAM': evict least recently used items until the new value fits, then add it
elif self.storage_type == "RAM":
while self._get_cache_size_bytes() > self.size:
while len(self.cache) > 1 and self._get_cache_size_bytes() + curr_val_size > self.size:  # evict until the new value fits, but always keep at least one entry
self.cache.drop(
index=self.cache.sort_values(
by=["time_stamp"], ascending=False
).index[-1],
axis=0,
inplace=True,
)
self.cache = self.cache.reset_index(drop=True)
return_val = func(instance, *args, **kwargs)
self.cache.loc[len(self.cache.index)] = [
instance._series["uuid"],
func.__name__,
args,
kwargs,
return_val,
time.time(),
]
# neither limit would be exceeded, so the new item can simply be added to the cache
else:
return_val = func(instance, *args, **kwargs)
self.cache.loc[len(self.cache.index)] = [
instance._series["uuid"],
func.__name__,
args,
kwargs,
return_val,
time.time(),
]

self.cache.reset_index(drop=True, inplace=True)

# now ready to add to cache
add_time = time.time()
self.cache.loc[len(self.cache)] = [
instance._series["uuid"],
func.__name__,
args_dict,
return_val,
add_time,
add_time,
curr_val_size,
]
return _return_wrapper(return_val, copy_bool=return_copy)

return _use_cache
return _use_cache_wrapper
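End to end, a method opts into caching as below. A hedged usage sketch: the `DemoExtensions` class and `get_output` method are hypothetical, `Cache()` uses the visible 1 GiB RAM default, and `wrapsmethod` is assumed to behave like `functools.wraps`:

import numpy as np
import pandas as pd

from mesmerize_core.caiman_extensions.cache import Cache  # module path per this PR

cache = Cache()  # default: storage_type "RAM", size 1024**3 bytes

class DemoExtensions:
    def __init__(self, series: pd.Series):
        self._series = series

    @cache.use_cache
    def get_output(self, scale: float = 1.0, *, return_copy: bool = True):
        # expensive work runs only on a cache miss
        return np.random.default_rng(0).random(1000) * scale

ext = DemoExtensions(pd.Series({"uuid": "abc-123"}))
a = ext.get_output(2.0)        # miss: computed, sized, and stored
b = ext.get_output(scale=2.0)  # hit: normalizes to the same kwargs dict
assert np.array_equal(a, b) and a is not b  # return_copy=True -> deep copies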


def invalidate(self, pre: bool = True, post: bool = True):
"""
@@ -202,7 +227,7 @@ def invalidate(self, pre: bool = True, post: bool = True):
"""

def _invalidate(func):
@wraps(func)
@wrapsmethod(func)
def __invalidate(instance, *args, **kwargs):
u = instance._series["uuid"]
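The rest of invalidate is collapsed in this view. Under the assumption that it drops cached rows whose uuid matches the modified item (with pre/post controlling whether that happens before and/or after the wrapped call), the core eviction step would look roughly like this hypothetical helper:

def _drop_entries_for(cache_df: pd.DataFrame, uuid) -> pd.DataFrame:
    # hypothetical sketch: keep only rows belonging to other batch items
    return cache_df[cache_df["uuid"] != uuid].reset_index(drop=True)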
