diff --git a/interpreter/llm/setup_openai_coding_llm.py b/interpreter/llm/setup_openai_coding_llm.py index 763dc4f181..da38068634 100644 --- a/interpreter/llm/setup_openai_coding_llm.py +++ b/interpreter/llm/setup_openai_coding_llm.py @@ -99,6 +99,9 @@ def coding_llm(messages): for chunk in response: + if interpreter.debug_mode: + print("Chunk from LLM", chunk) + if ('choices' not in chunk or len(chunk['choices']) == 0): # This happens sometimes continue @@ -108,31 +111,60 @@ def coding_llm(messages): # Accumulate deltas accumulated_deltas = merge_deltas(accumulated_deltas, delta) + if interpreter.debug_mode: + print("Accumulated deltas", accumulated_deltas) + if "content" in delta and delta["content"]: yield {"message": delta["content"]} if ("function_call" in accumulated_deltas and "arguments" in accumulated_deltas["function_call"]): - arguments = accumulated_deltas["function_call"]["arguments"] - arguments = parse_partial_json(arguments) - - if arguments: - - if (language is None - and "language" in arguments - and "code" in arguments # <- This ensures we're *finished* typing language, as opposed to partially done - and arguments["language"]): - language = arguments["language"] + if ("name" in accumulated_deltas["function_call"] and accumulated_deltas["function_call"]["name"] == "execute"): + arguments = accumulated_deltas["function_call"]["arguments"] + arguments = parse_partial_json(arguments) + + if arguments: + if (language is None + and "language" in arguments + and "code" in arguments # <- This ensures we're *finished* typing language, as opposed to partially done + and arguments["language"]): + language = arguments["language"] + yield {"language": language} + + if language is not None and "code" in arguments: + # Calculate the delta (new characters only) + code_delta = arguments["code"][len(code):] + # Update the code + code = arguments["code"] + # Yield the delta + if code_delta: + yield {"code": code_delta} + else: + if interpreter.debug_mode: + print("Arguments not a dict.") + + # 3.5 REALLY likes to halucinate a function named `python` and you can't really fix that, it seems. + # We just need to deal with it. + elif ("name" in accumulated_deltas["function_call"] and accumulated_deltas["function_call"]["name"] == "python"): + if interpreter.debug_mode: + print("Got direct python call") + if (language is None): + language = "python" yield {"language": language} - - if language is not None and "code" in arguments: - # Calculate the delta (new characters only) - code_delta = arguments["code"][len(code):] + + if language is not None: + # Pull the code string straight out of the "arguments" string + code_delta = accumulated_deltas["function_call"]["arguments"][len(code):] # Update the code - code = arguments["code"] + code = accumulated_deltas["function_call"]["arguments"] # Yield the delta if code_delta: - yield {"code": code_delta} - + yield {"code": code_delta} + + else: + if interpreter.debug_mode: + print("GOT BAD FUNCTION CALL: ", accumulated_deltas["function_call"]) + + return coding_llm \ No newline at end of file