diff --git a/markovify/chain.py b/markovify/chain.py index b19e8e2..fef8ae5 100644 --- a/markovify/chain.py +++ b/markovify/chain.py @@ -118,6 +118,20 @@ def move(self, state): r = random.random() * cumdist[-1] selection = choices[bisect.bisect(cumdist, r)] return selection + + def greedy_move(self, state): + """ + Given a state, choose the most likely next item + """ + if self.compiled: + choices, cumdist = self.model[state] + elif state == tuple([BEGIN] * self.state_size): + choices = self.begin_choices + cumdist = self.begin_cumdist + else: + choices, weights = zip(*self.model[state].items()) + cumdist = list(accumulate(weights)) + return choices[0] def gen(self, init_state=None): """