Skip to content

Commit 77971e5

Browse files
chg: Update split_parse to deal with multiple predictions
1 parent 20afa75 commit 77971e5

File tree

1 file changed

+14
-16
lines changed

1 file changed

+14
-16
lines changed

deep_reference_parser/split_parse.py

+14-16
Original file line numberDiff line numberDiff line change
@@ -115,28 +115,26 @@ def split_parse(self, text, return_tokens=False, verbose=False):
115115

116116
preds = self.drp.predict(tokens, load_weights=True)
117117

118-
return preds
119-
120118
# If tokens argument passed, return the labelled tokens
121119

122-
#if return_tokens:
120+
if return_tokens:
123121

124-
# flat_predictions = list(itertools.chain.from_iterable(preds))
125-
# flat_X = list(itertools.chain.from_iterable(tokens))
126-
# rows = [i for i in zip(flat_X, flat_predictions)]
122+
flat_preds_list = list(map(itertools.chain.from_iterable,preds))
123+
flat_X = list(itertools.chain.from_iterable(tokens))
124+
rows = [i for i in zip(*[flat_X] + flat_preds_list)]
127125

128-
# if verbose:
126+
if verbose:
129127

130-
# msg.divider("Token Results")
128+
msg.divider("Token Results")
131129

132-
# header = ("token", "label")
133-
# aligns = ("r", "l")
134-
# formatted = wasabi.table(
135-
# rows, header=header, divider=True, aligns=aligns
136-
# )
137-
# print(formatted)
130+
header = tuple(["token"] + ["label"] * len(flat_preds_list))
131+
aligns = tuple(["r"] + ["l"] * len(flat_preds_list))
132+
formatted = wasabi.table(
133+
rows, header=header, divider=True, aligns=aligns
134+
)
135+
print(formatted)
138136

139-
# out = rows
137+
out = rows
140138

141139
#else:
142140

@@ -185,7 +183,7 @@ def split_parse(text, config_file=MULTITASK_CFG, tokens=False, outfile=None):
185183
"""
186184
mt = SplitParser(config_file)
187185
if outfile:
188-
out = mt.split_parse(text, return_tokens=tokens, verbose=False)
186+
out = mt.split_parse(text, return_tokens=tokens, verbose=True)
189187

190188
try:
191189
with open(outfile, "w") as fb:

0 commit comments

Comments
 (0)