1
+ import asyncio
1
2
import functools
2
3
import os
3
4
import signal
7
8
import warnings
8
9
from contextlib import contextmanager
9
10
from pathlib import Path
10
- from typing import Any , Callable , Dict , List , Optional
11
+ from typing import Any , Callable , Dict , List , Optional , Union
11
12
12
13
import openai
13
14
import pytest
@@ -476,7 +477,8 @@ async def completions_with_server_args(
476
477
server_cli_args : List [str ],
477
478
num_logprobs : Optional [int ],
478
479
max_wait_seconds : int = 240 ,
479
- ) -> Completion :
480
+ max_tokens : Union [int , list ] = 5 ,
481
+ ) -> List [Completion ]:
480
482
'''Construct a remote OpenAI server, obtain an async client to the
481
483
server & invoke the completions API to obtain completions.
482
484
@@ -487,37 +489,49 @@ async def completions_with_server_args(
487
489
num_logprobs: Number of logprobs to report (or `None`)
488
490
max_wait_seconds: timeout interval for bringing up server.
489
491
Default: 240sec
492
+ max_tokens: max_tokens value for each of the given input prompts.
493
+ if only one max_token value is given, the same value is used
494
+ for all the prompts.
490
495
491
496
Returns:
492
497
OpenAI Completion instance
493
498
'''
494
499
500
+ if isinstance (max_tokens , int ):
501
+ max_tokens = [max_tokens ] * len (prompts )
502
+
503
+ assert len (max_tokens ) == len (prompts )
504
+
495
505
outputs = None
496
506
max_wait_seconds = 240 * 3 # 240 is default
497
507
with RemoteOpenAIServer (model_name ,
498
508
server_cli_args ,
499
509
max_wait_seconds = max_wait_seconds ) as server :
500
510
client = server .get_async_client ()
501
- outputs = await client .completions .create (model = model_name ,
502
- prompt = prompts ,
503
- temperature = 0 ,
504
- stream = False ,
505
- max_tokens = 5 ,
506
- logprobs = num_logprobs )
511
+ outputs = [ client .completions .create (model = model_name ,
512
+ prompt = [p ],
513
+ temperature = 0 ,
514
+ stream = False ,
515
+ max_tokens = max_tok ,
516
+ logprobs = num_logprobs ) \
517
+ for p , max_tok in zip (prompts , max_tokens ) ]
518
+ outputs = await asyncio .gather (* outputs )
519
+
507
520
assert outputs is not None , "Completion API call failed."
508
521
509
522
return outputs
510
523
511
524
512
def get_client_text_generations(completions: List[Completion]) -> List[str]:
    '''Extract the generated text from each response returned by an
    Open-AI-protocol completions endpoint.

    Each entry in `completions` is expected to carry exactly one choice
    (one prompt per request); that single choice's text is collected,
    preserving request order.
    '''
    texts: List[str] = []
    for completion in completions:
        assert len(completion.choices) == 1
        texts.append(completion.choices[0].text)
    return texts
517
531
518
532
519
533
def get_client_text_logprob_generations (
520
- completions : Completion ) -> List [TextTextLogprobs ]:
534
+ completions : List [ Completion ] ) -> List [TextTextLogprobs ]:
521
535
'''Operates on the output of a request made to an Open-AI-protocol
522
536
completions endpoint; obtains top-rank logprobs for each token in
523
537
each :class:`SequenceGroup`
@@ -526,4 +540,4 @@ def get_client_text_logprob_generations(
526
540
text = '' .join (text_generations )
527
541
return [(text_generations , text ,
528
542
(None if x .logprobs is None else x .logprobs .top_logprobs ))
529
- for x in completions .choices ]
543
+ for completion in completions for x in completion .choices ]
0 commit comments