Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 85 additions & 21 deletions markovify/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,50 @@
import bisect
import json
import copy
from typing import (
Callable,
Dict,
Generator,
Iterable,
List,
Tuple,
Type,
TypeVar,
Union,
cast,
)


# Sentinel items placed before and after every run when a model is built.
BEGIN = "___BEGIN__"
END = "___END__"


# A state is a tuple of items; it is the key of the model dict.
State = Tuple[str, ...]
# Follower item -> integer weight for one state.
NextDict = Dict[str, int]
# Parallel (followers, cumulative weights) pair, as returned by compile_next.
NextCompiled = Tuple[List[str], List[int]]
ModelUncompiled = Dict[State, NextDict]
ModelCompiled = Dict[State, NextCompiled]
Model = Union[ModelUncompiled, ModelCompiled]


T = TypeVar("T")
# Lets classmethods declare that they return an instance of the class
# they were called on.
ChainT = TypeVar("ChainT", bound="Chain")
# Element type of `accumulate`, and the binary function that folds it.
AccT = TypeVar("AccT")
AccF = Callable[[AccT, AccT], AccT]


class ParamError(Exception):
    """Raised when a `Chain` is given neither a `corpus` nor a `model`."""


def cast_not_none(var: Union[T, None]) -> T:
    """
    Tell the type checker that `var` is not None.

    This is a static-typing aid only: `typing.cast` returns its argument
    unchanged and performs no runtime check, so a None passed in comes
    straight back out as None.
    """
    return cast(T, var)


def accumulate(
iterable: Iterable[AccT],
func: AccF[AccT] = operator.add,
) -> Generator[AccT, None, None]:
"""
Cumulative calculations. (Summation, by default.)
Via: https://docs.python.org/3/library/itertools.html#itertools.accumulate
Expand All @@ -21,10 +59,10 @@ def accumulate(iterable, func=operator.add):
yield total


def compile_next(next_dict: NextDict) -> NextCompiled:
    """
    Split a {follower: weight} dict into two parallel sequences: the
    followers, and the running totals of their weights (via `accumulate`).

    Returns a `(words, cumulative_weights)` tuple.
    """
    words = list(next_dict.keys())
    cff = list(accumulate(next_dict.values()))
    # Returned as a tuple so the value matches the NextCompiled annotation.
    return words, cff


class Chain:
Expand All @@ -33,7 +71,12 @@ class Chain:
For example: Sentences.
"""

def __init__(self, corpus, state_size, model=None):
def __init__(
self,
corpus: Union[Iterable[List[str]], None],
state_size: int,
model: Union[ModelCompiled, None] = None,
):
"""
`corpus`: A list of lists, where each outer list is a "run"
of the process (e.g., a single sentence), and each inner list
Expand All @@ -44,29 +87,33 @@ def __init__(self, corpus, state_size, model=None):
`state_size`: An integer indicating the number of items the model
uses to represent its state. For text generation, 2 or 3 are typical.
"""
if corpus is None and model is None:
raise ParamError("Must provide either `corpus` or `model`.")
self.state_size = state_size
self.model = model or self.build(corpus, self.state_size)
self.model = model or self.build(cast_not_none(corpus), self.state_size)
self.compiled = (len(self.model) > 0) and (
type(self.model[tuple([BEGIN] * state_size)]) == list
isinstance(self.model[tuple([BEGIN] * state_size)], (tuple, list))
)
if not self.compiled:
self.precompute_begin_state()

def compile(self, inplace: bool = False) -> "Chain":
    """
    Return a chain whose model stores, for every state, the
    (followers, cumulative weights) pair produced by `compile_next`.

    `inplace`: If True, convert this chain's own model and return `self`.
    If False, leave this chain untouched and return a new `Chain`.
    """
    if self.compiled:
        if inplace:
            return self
        # Already compiled: hand back an independent deep copy.
        compiled_model = cast(ModelCompiled, self.model)
        return Chain(None, self.state_size, model=copy.deepcopy(compiled_model))
    raw_model = cast(ModelUncompiled, self.model)
    mdict: ModelCompiled = {
        state: compile_next(next_dict) for (state, next_dict) in raw_model.items()
    }
    if not inplace:
        return Chain(None, self.state_size, model=mdict)
    self.model = mdict
    self.compiled = True
    return self

def build(self, corpus, state_size):
def build(self, corpus: Iterable[List[str]], state_size: int) -> Model:
"""
Build a Python representation of the Markov model. Returns a dict
of dicts where the keys of the outer dict represent all possible states,
Expand All @@ -77,7 +124,7 @@ def build(self, corpus, state_size):

# Using a DefaultDict here would be a lot more convenient, however the memory
# usage is far higher.
model = {}
model: Model = {}

for run in corpus:
items = ([BEGIN] * state_size) + run + [END]
Expand All @@ -93,33 +140,40 @@ def build(self, corpus, state_size):
model[state][follow] += 1
return model

def precompute_begin_state(self) -> None:
    """
    Caches the summation calculation and available choices for BEGIN * state_size.
    Significantly speeds up chain generation on large corpora. Thanks, @schollz!

    Stores the results on `self.begin_choices` and `self.begin_cumdist`,
    which `move` reads when asked for the begin state.
    """
    model = cast(ModelUncompiled, self.model)
    begin_state = tuple([BEGIN] * self.state_size)
    choices, cumdist = compile_next(model[begin_state])
    self.begin_cumdist = cumdist
    self.begin_choices = choices

def move(self, state: State) -> str:
    """
    Given a state, choose the next item at random.

    Each follower of `state` is picked with probability proportional to
    its weight in the model.
    """
    if self.compiled:
        # Compiled models already hold (followers, cumulative weights).
        compiled_model = cast(ModelCompiled, self.model)
        choices, cumdist = compiled_model[state]
    elif state == tuple([BEGIN] * self.state_size):
        # The begin state was cached by precompute_begin_state.
        choices = self.begin_choices
        cumdist = self.begin_cumdist
    else:
        raw_model = cast(ModelUncompiled, self.model)
        followers = raw_model[state]
        choices = tuple(followers.keys())
        weights = tuple(followers.values())
        cumdist = list(accumulate(weights))
    # Scale a uniform draw to the total weight, then find the first
    # cumulative total that exceeds it.
    r = random.random() * cumdist[-1]
    selection = choices[bisect.bisect(cumdist, r)]
    return selection

def gen(self, init_state=None):
def gen(
self,
init_state: Union[State, None] = None,
) -> Generator[str, None, None]:
"""
Starting either with a naive BEGIN state, or the provided `init_state`
(as a tuple), return a generator that will yield successive items
Expand All @@ -133,22 +187,25 @@ def gen(self, init_state=None):
yield next_word
state = tuple(state[1:]) + (next_word,)

def walk(self, init_state: Union[State, None] = None) -> List[str]:
    """
    Return a list representing a single run of the Markov model, either
    starting with a naive BEGIN state, or the provided `init_state`
    (as a tuple).
    """
    # `gen` is lazy; materialise the whole run before returning it.
    return list(self.gen(init_state))

def to_json(self) -> str:
    """
    Dump the model as a JSON object, for loading later.
    """
    # States are tuples, which JSON cannot use as object keys, so the model
    # is written as a list of [state, value] pairs rather than as a dict.
    return json.dumps(list(self.model.items()))

@classmethod
def from_json(cls, json_thing):
def from_json(
cls: Type[ChainT],
json_thing: Union[str, Dict, List],
) -> ChainT:
"""
Given a JSON object or JSON string that was created by `self.to_json`,
return the corresponding markovify.Chain.
Expand All @@ -165,5 +222,12 @@ def from_json(cls, json_thing):

state_size = len(list(rehydrated.keys())[0])

compiled = (len(rehydrated) > 0) and (
isinstance(rehydrated[tuple([BEGIN] * state_size)], list)
)
if compiled:
for state in rehydrated:
rehydrated[state] = tuple(rehydrated[state])

inst = cls(None, state_size, rehydrated)
return inst
7 changes: 4 additions & 3 deletions markovify/splitters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from typing import List

uppercase_letter_pat = re.compile(r"^[A-Z]$", re.UNICODE)
initialism_pat = re.compile(r"^[A-Za-z0-9]{1,2}(\.[A-Za-z0-9]{1,2})+\.$", re.UNICODE)
Expand All @@ -22,7 +23,7 @@
abbr_lowercase = "etc|v|vs|viz|al|pct".split("|")


def is_abbreviation(dotted_word):
def is_abbreviation(dotted_word: str) -> bool:
clipped = dotted_word[:-1]
if re.match(uppercase_letter_pat, clipped[0]):
if len(clipped) == 1: # Initial
Expand All @@ -38,7 +39,7 @@ def is_abbreviation(dotted_word):
return False


def is_sentence_ender(word):
def is_sentence_ender(word: str) -> bool:
if re.match(initialism_pat, word) is not None:
return False
if word[-1] in ["?", "!"]:
Expand All @@ -50,7 +51,7 @@ def is_sentence_ender(word):
return False


def split_into_sentences(text):
def split_into_sentences(text: str) -> List[str]:
potential_end_pat = re.compile(
r"".join(
[
Expand Down
Loading