@@ -115,28 +115,26 @@ def split_parse(self, text, return_tokens=False, verbose=False):
115
115
116
116
preds = self .drp .predict (tokens , load_weights = True )
117
117
118
- return preds
119
-
120
118
# If tokens argument passed, return the labelled tokens
121
119
122
- # if return_tokens:
120
+ if return_tokens :
123
121
124
- # flat_predictions = list(itertools.chain.from_iterable( preds))
125
- # flat_X = list(itertools.chain.from_iterable(tokens))
126
- # rows = [i for i in zip(flat_X, flat_predictions )]
122
+ flat_preds_list = list (map ( itertools .chain .from_iterable , preds ))
123
+ flat_X = list (itertools .chain .from_iterable (tokens ))
124
+ rows = [i for i in zip (* [ flat_X ] + flat_preds_list )]
127
125
128
- # if verbose:
126
+ if verbose :
129
127
130
- # msg.divider("Token Results")
128
+ msg .divider ("Token Results" )
131
129
132
- # header = ( "token", "label")
133
- # aligns = ( "r", "l")
134
- # formatted = wasabi.table(
135
- # rows, header=header, divider=True, aligns=aligns
136
- # )
137
- # print(formatted)
130
+ header = tuple ([ "token" ] + [ "label" ] * len ( flat_preds_list ) )
131
+ aligns = tuple ([ "r" ] + [ "l" ] * len ( flat_preds_list ) )
132
+ formatted = wasabi .table (
133
+ rows , header = header , divider = True , aligns = aligns
134
+ )
135
+ print (formatted )
138
136
139
- # out = rows
137
+ out = rows
140
138
141
139
#else:
142
140
@@ -185,7 +183,7 @@ def split_parse(text, config_file=MULTITASK_CFG, tokens=False, outfile=None):
185
183
"""
186
184
mt = SplitParser (config_file )
187
185
if outfile :
188
- out = mt .split_parse (text , return_tokens = tokens , verbose = False )
186
+ out = mt .split_parse (text , return_tokens = tokens , verbose = True )
189
187
190
188
try :
191
189
with open (outfile , "w" ) as fb :
0 commit comments