diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000000..00d7e68474 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,5 @@ +# Migrate code style to Black +a463f15ffd9dcc36af4829b78b5e2a81d304f500 + +# Rerun formatting after merge conflict resolution +b6fac19ce47b0f16211948b907d8b3c55dcee985 \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..b664c41393 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: + # Using this mirror lets us use mypyc-compiled black, which is about 2x faster + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 23.10.1 + hooks: + - id: black + # It is recommended to specify the latest version of Python + # supported by your project here, or alternatively use + # pre-commit's default_language_version, see + # https://pre-commit.com/#top_level-default_language_version + language_version: python3.11 + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 38ea963cd8..6b9007bebc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -45,10 +45,16 @@ Once you've forked the code and created a new branch for your work, you can run After modifying the source code, you will need to do `poetry run interpreter` again. +**Note**: This project uses [`black`](https://black.readthedocs.io/en/stable/index.html) and [`isort`](https://pypi.org/project/isort/) via a [`pre-commit`](https://pre-commit.com/) hook to ensure consistent code style. If you need to bypass it for some reason, you can `git commit` with the `--no-verify` flag. + ### Installing New Packages If you wish to install new dependencies into the project, please use `poetry add package-name`. +#### Installing Developer Dependencies + +If you need to install dependencies specific to development, like testing tools, formatting tools, etc. 
please use `poetry add package-name --group dev`. + ### Known Issues For some, `poetry install` might hang on some dependencies. As a first step, try to run the following command in your terminal: diff --git a/interpreter/__init__.py b/interpreter/__init__.py index 6ca64ffaa0..b92a2f59c0 100644 --- a/interpreter/__init__.py +++ b/interpreter/__init__.py @@ -1,6 +1,7 @@ -from .core.core import Interpreter import sys +from .core.core import Interpreter + # This is done so when users `import interpreter`, # they get an instance of interpreter: @@ -11,9 +12,9 @@ # But I think it saves a step, removes friction, and looks good. -# ____ ____ __ __ +# ____ ____ __ __ # / __ \____ ___ ____ / _/___ / /____ _________ ________ / /____ _____ # / / / / __ \/ _ \/ __ \ / // __ \/ __/ _ \/ ___/ __ \/ ___/ _ \/ __/ _ \/ ___/ -# / /_/ / /_/ / __/ / / / _/ // / / / /_/ __/ / / /_/ / / / __/ /_/ __/ / -# \____/ .___/\___/_/ /_/ /___/_/ /_/\__/\___/_/ / .___/_/ \___/\__/\___/_/ -# /_/ /_/ \ No newline at end of file +# / /_/ / /_/ / __/ / / / _/ // / / / /_/ __/ / / /_/ / / / __/ /_/ __/ / +# \____/ .___/\___/_/ /_/ /___/_/ /_/\__/\___/_/ / .___/_/ \___/\__/\___/_/ +# /_/ /_/ diff --git a/interpreter/cli/cli.py b/interpreter/cli/cli.py index 91c16de18f..7ba0440a42 100644 --- a/interpreter/cli/cli.py +++ b/interpreter/cli/cli.py @@ -1,12 +1,14 @@ import argparse -import subprocess import os import platform -import pkg_resources -import ooba +import subprocess + import appdirs -from ..utils.get_config import get_config_path +import ooba +import pkg_resources + from ..terminal_interface.conversation_navigator import conversation_navigator +from ..utils.get_config import get_config_path arguments = [ { @@ -15,7 +17,12 @@ "help_text": "prompt / custom instructions for the language model", "type": str, }, - {"name": "local", "nickname": "l", "help_text": "run the language model locally (experimental)", "type": bool}, + { + "name": "local", + "nickname": "l", + "help_text": "run the language 
model locally (experimental)", + "type": bool, + }, { "name": "auto_run", "nickname": "y", @@ -76,7 +83,7 @@ "help_text": "optionally enable safety mechanisms like code scanning; valid options are off, ask, and auto", "type": str, "choices": ["off", "ask", "auto"], - "default": "off" + "default": "off", }, { "name": "gguf_quality", @@ -148,10 +155,10 @@ def cli(interpreter): help="get Open Interpreter's version number", ) parser.add_argument( - '--change_local_device', - dest='change_local_device', - action='store_true', - help="change the device used for local execution (if GPU fails, will use CPU)" + "--change_local_device", + dest="change_local_device", + action="store_true", + help="change the device used for local execution (if GPU fails, will use CPU)", ) # TODO: Implement model explorer @@ -204,7 +211,9 @@ def cli(interpreter): setattr(interpreter, attr_name, attr_value) # if safe_mode and auto_run are enabled, safe_mode disables auto_run - if interpreter.auto_run and (interpreter.safe_mode == "ask" or interpreter.safe_mode == "auto"): + if interpreter.auto_run and ( + interpreter.safe_mode == "ask" or interpreter.safe_mode == "auto" + ): setattr(interpreter, "auto_run", False) # Default to Mistral if --local is on but --model is unset @@ -223,7 +232,9 @@ def cli(interpreter): return if args.change_local_device: - print("This will uninstall the experimental local LLM interface (Ooba) in order to reinstall it for a new local device. Proceed? (y/n)") + print( + "This will uninstall the experimental local LLM interface (Ooba) in order to reinstall it for a new local device. Proceed? (y/n)" + ) if input().lower() == "n": return @@ -237,11 +248,13 @@ def cli(interpreter): gpu_choice = input("> ").upper() - while gpu_choice not in 'ABCDN': + while gpu_choice not in "ABCDN": print("Invalid choice. 
Please try again.") gpu_choice = input("> ").upper() - ooba.install(force_reinstall=True, gpu_choice=gpu_choice, verbose=args.debug_mode) + ooba.install( + force_reinstall=True, gpu_choice=gpu_choice, verbose=args.debug_mode + ) return # Deprecated --fast diff --git a/interpreter/code_interpreters/base_code_interpreter.py b/interpreter/code_interpreters/base_code_interpreter.py index 9d672aa0cc..23796e4240 100644 --- a/interpreter/code_interpreters/base_code_interpreter.py +++ b/interpreter/code_interpreters/base_code_interpreter.py @@ -1,9 +1,8 @@ - - class BaseCodeInterpreter: """ .run is a generator that yields a dict with attributes: active_line, output """ + def __init__(self): pass @@ -11,4 +10,4 @@ def run(self, code): pass def terminate(self): - pass \ No newline at end of file + pass diff --git a/interpreter/code_interpreters/create_code_interpreter.py b/interpreter/code_interpreters/create_code_interpreter.py index e18db43efd..89c3dfb158 100644 --- a/interpreter/code_interpreters/create_code_interpreter.py +++ b/interpreter/code_interpreters/create_code_interpreter.py @@ -1,5 +1,6 @@ from .language_map import language_map + def create_code_interpreter(language): # Case in-sensitive language = language.lower() diff --git a/interpreter/code_interpreters/language_map.py b/interpreter/code_interpreters/language_map.py index e93d3a7c44..fa955c9001 100644 --- a/interpreter/code_interpreters/language_map.py +++ b/interpreter/code_interpreters/language_map.py @@ -1,11 +1,10 @@ -from .languages.python import Python -from .languages.shell import Shell -from .languages.javascript import JavaScript -from .languages.html import HTML from .languages.applescript import AppleScript -from .languages.r import R +from .languages.html import HTML +from .languages.javascript import JavaScript from .languages.powershell import PowerShell - +from .languages.python import Python +from .languages.r import R +from .languages.shell import Shell language_map = { "python": Python, 
diff --git a/interpreter/code_interpreters/languages/applescript.py b/interpreter/code_interpreters/languages/applescript.py index 403c3e0fc3..80446fbf55 100644 --- a/interpreter/code_interpreters/languages/applescript.py +++ b/interpreter/code_interpreters/languages/applescript.py @@ -1,13 +1,15 @@ import os + from ..subprocess_code_interpreter import SubprocessCodeInterpreter + class AppleScript(SubprocessCodeInterpreter): file_extension = "applescript" proper_name = "AppleScript" def __init__(self): super().__init__() - self.start_cmd = os.environ.get('SHELL', '/bin/zsh') + self.start_cmd = os.environ.get("SHELL", "/bin/zsh") def preprocess_code(self, code): """ @@ -17,17 +19,17 @@ def preprocess_code(self, code): code = self.add_active_line_indicators(code) # Escape double quotes - code = code.replace('"', r'\"') - + code = code.replace('"', r"\"") + # Wrap in double quotes code = '"' + code + '"' - + # Prepend start command for AppleScript code = "osascript -e " + code # Append end of execution indicator code += '; echo "## end_of_execution ##"' - + return code def add_active_line_indicators(self, code): @@ -35,7 +37,7 @@ def add_active_line_indicators(self, code): Adds log commands to indicate the active line of execution in the AppleScript. """ modified_lines = [] - lines = code.split('\n') + lines = code.split("\n") for idx, line in enumerate(lines): # Add log command to indicate the line number @@ -43,7 +45,7 @@ def add_active_line_indicators(self, code): modified_lines.append(f'log "## active_line {idx + 1} ##"') modified_lines.append(line) - return '\n'.join(modified_lines) + return "\n".join(modified_lines) def detect_active_line(self, line): """ @@ -61,4 +63,4 @@ def detect_end_of_execution(self, line): """ Detects end of execution marker in the output. 
""" - return "## end_of_execution ##" in line \ No newline at end of file + return "## end_of_execution ##" in line diff --git a/interpreter/code_interpreters/languages/html.py b/interpreter/code_interpreters/languages/html.py index 965b38717b..6a12102e8d 100644 --- a/interpreter/code_interpreters/languages/html.py +++ b/interpreter/code_interpreters/languages/html.py @@ -1,8 +1,10 @@ -import webbrowser -import tempfile import os +import tempfile +import webbrowser + from ..base_code_interpreter import BaseCodeInterpreter + class HTML(BaseCodeInterpreter): file_extension = "html" proper_name = "HTML" @@ -16,6 +18,8 @@ def run(self, code): f.write(code.encode()) # Open the HTML file with the default web browser - webbrowser.open('file://' + os.path.realpath(f.name)) + webbrowser.open("file://" + os.path.realpath(f.name)) - yield {"output": f"Saved to {os.path.realpath(f.name)} and opened with the user's default web browser."} \ No newline at end of file + yield { + "output": f"Saved to {os.path.realpath(f.name)} and opened with the user's default web browser." 
+ } diff --git a/interpreter/code_interpreters/languages/javascript.py b/interpreter/code_interpreters/languages/javascript.py index d5e74ff824..a6279a3f2e 100644 --- a/interpreter/code_interpreters/languages/javascript.py +++ b/interpreter/code_interpreters/languages/javascript.py @@ -1,6 +1,8 @@ -from ..subprocess_code_interpreter import SubprocessCodeInterpreter import re +from ..subprocess_code_interpreter import SubprocessCodeInterpreter + + class JavaScript(SubprocessCodeInterpreter): file_extension = "js" proper_name = "JavaScript" @@ -8,10 +10,10 @@ class JavaScript(SubprocessCodeInterpreter): def __init__(self): super().__init__() self.start_cmd = "node -i" - + def preprocess_code(self, code): return preprocess_javascript(code) - + def line_postprocessor(self, line): # Node's interactive REPL outputs a billion things # So we clean it up: @@ -20,7 +22,7 @@ def line_postprocessor(self, line): if line.strip() in ["undefined", 'Type ".help" for more information.']: return None # Remove trailing ">"s - line = re.sub(r'^\s*(>\s*)+', '', line) + line = re.sub(r"^\s*(>\s*)+", "", line) return line def detect_active_line(self, line): @@ -30,7 +32,7 @@ def detect_active_line(self, line): def detect_end_of_execution(self, line): return "## end_of_execution ##" in line - + def preprocess_javascript(code): """ @@ -61,4 +63,4 @@ def preprocess_javascript(code): console.log("## end_of_execution ##"); """ - return processed_code \ No newline at end of file + return processed_code diff --git a/interpreter/code_interpreters/languages/powershell.py b/interpreter/code_interpreters/languages/powershell.py index a5ff774c31..d96615e18a 100644 --- a/interpreter/code_interpreters/languages/powershell.py +++ b/interpreter/code_interpreters/languages/powershell.py @@ -1,7 +1,9 @@ -import platform import os +import platform + from ..subprocess_code_interpreter import SubprocessCodeInterpreter + class PowerShell(SubprocessCodeInterpreter): file_extension = "ps1" proper_name = 
"PowerShell" @@ -10,11 +12,11 @@ def __init__(self): super().__init__() # Determine the start command based on the platform (use "powershell" for Windows) - if platform.system() == 'Windows': - self.start_cmd = 'powershell.exe' - #self.start_cmd = os.environ.get('SHELL', 'powershell.exe') + if platform.system() == "Windows": + self.start_cmd = "powershell.exe" + # self.start_cmd = os.environ.get('SHELL', 'powershell.exe') else: - self.start_cmd = os.environ.get('SHELL', 'bash') + self.start_cmd = os.environ.get("SHELL", "bash") def preprocess_code(self, code): return preprocess_powershell(code) @@ -30,6 +32,7 @@ def detect_active_line(self, line): def detect_end_of_execution(self, line): return "## end_of_execution ##" in line + def preprocess_powershell(code): """ Add active line markers @@ -47,15 +50,17 @@ def preprocess_powershell(code): return code + def add_active_line_prints(code): """ Add Write-Output statements indicating line numbers to a PowerShell script. """ - lines = code.split('\n') + lines = code.split("\n") for index, line in enumerate(lines): # Insert the Write-Output command before the actual line lines[index] = f'Write-Output "## active_line {index + 1} ##"\n{line}' - return '\n'.join(lines) + return "\n".join(lines) + def wrap_in_try_catch(code): """ @@ -65,4 +70,4 @@ def wrap_in_try_catch(code): try { $ErrorActionPreference = "Stop" """ - return try_catch_code + code + "\n} catch {\n Write-Error $_\n}\n" \ No newline at end of file + return try_catch_code + code + "\n} catch {\n Write-Error $_\n}\n" diff --git a/interpreter/code_interpreters/languages/python.py b/interpreter/code_interpreters/languages/python.py index 8747a42661..7e8e107902 100644 --- a/interpreter/code_interpreters/languages/python.py +++ b/interpreter/code_interpreters/languages/python.py @@ -1,9 +1,11 @@ -import os -import sys -from ..subprocess_code_interpreter import SubprocessCodeInterpreter import ast +import os import re import shlex +import sys + +from 
..subprocess_code_interpreter import SubprocessCodeInterpreter + class Python(SubprocessCodeInterpreter): file_extension = "py" @@ -12,15 +14,15 @@ class Python(SubprocessCodeInterpreter): def __init__(self): super().__init__() executable = sys.executable - if os.name != 'nt': # not Windows + if os.name != "nt": # not Windows executable = shlex.quote(executable) self.start_cmd = executable + " -i -q -u" - + def preprocess_code(self, code): return preprocess_python(code) - + def line_postprocessor(self, line): - if re.match(r'^(\s*>>>\s*|\s*\.\.\.\s*)', line): + if re.match(r"^(\s*>>>\s*|\s*\.\.\.\s*)", line): return None return line @@ -31,7 +33,7 @@ def detect_active_line(self, line): def detect_end_of_execution(self, line): return "## end_of_execution ##" in line - + def preprocess_python(code): """ @@ -78,9 +80,9 @@ def insert_print_statement(self, line_number): """Inserts a print statement for a given line number.""" return ast.Expr( value=ast.Call( - func=ast.Name(id='print', ctx=ast.Load()), + func=ast.Name(id="print", ctx=ast.Load()), args=[ast.Constant(value=f"## active_line {line_number} ##")], - keywords=[] + keywords=[], ) ) @@ -93,7 +95,7 @@ def process_body(self, body): body = [body] for sub_node in body: - if hasattr(sub_node, 'lineno'): + if hasattr(sub_node, "lineno"): new_body.append(self.insert_print_statement(sub_node.lineno)) new_body.append(sub_node) @@ -104,11 +106,11 @@ def visit(self, node): new_node = super().visit(node) # If node has a body, process it - if hasattr(new_node, 'body'): + if hasattr(new_node, "body"): new_node.body = self.process_body(new_node.body) # If node has an orelse block (like in for, while, if), process it - if hasattr(new_node, 'orelse') and new_node.orelse: + if hasattr(new_node, "orelse") and new_node.orelse: new_node.orelse = self.process_body(new_node.orelse) # Special case for Try nodes as they have multiple blocks @@ -119,7 +121,7 @@ def visit(self, node): new_node.finalbody = 
self.process_body(new_node.finalbody) return new_node - + def wrap_in_try_except(code): # Add import traceback @@ -138,16 +140,20 @@ def wrap_in_try_except(code): body=[ ast.Expr( value=ast.Call( - func=ast.Attribute(value=ast.Name(id="traceback", ctx=ast.Load()), attr="print_exc", ctx=ast.Load()), + func=ast.Attribute( + value=ast.Name(id="traceback", ctx=ast.Load()), + attr="print_exc", + ctx=ast.Load(), + ), args=[], - keywords=[] + keywords=[], ) ), - ] + ], ) ], orelse=[], - finalbody=[] + finalbody=[], ) # Assign the try-except block as the new body diff --git a/interpreter/code_interpreters/languages/r.py b/interpreter/code_interpreters/languages/r.py index 16f51f93cf..1500f8c1b1 100644 --- a/interpreter/code_interpreters/languages/r.py +++ b/interpreter/code_interpreters/languages/r.py @@ -1,6 +1,8 @@ -from ..subprocess_code_interpreter import SubprocessCodeInterpreter import re +from ..subprocess_code_interpreter import SubprocessCodeInterpreter + + class R(SubprocessCodeInterpreter): file_extension = "r" proper_name = "R" @@ -8,7 +10,7 @@ class R(SubprocessCodeInterpreter): def __init__(self): super().__init__() self.start_cmd = "R -q --vanilla" # Start R in quiet and vanilla mode - + def preprocess_code(self, code): """ Add active line markers @@ -38,22 +40,26 @@ def preprocess_code(self, code): # Count the number of lines of processed_code # (R echoes all code back for some reason, but we can skip it if we track this!) 
self.code_line_count = len(processed_code.split("\n")) - 1 - + return processed_code - + def line_postprocessor(self, line): # If the line count attribute is set and non-zero, decrement and skip the line if hasattr(self, "code_line_count") and self.code_line_count > 0: self.code_line_count -= 1 return None - if re.match(r'^(\s*>>>\s*|\s*\.\.\.\s*|\s*>\s*|\s*\+\s*|\s*)$', line): + if re.match(r"^(\s*>>>\s*|\s*\.\.\.\s*|\s*>\s*|\s*\+\s*|\s*)$", line): return None if "R version" in line: # Startup message return None - if line.strip().startswith("[1] \"") and line.endswith("\""): # For strings, trim quotation marks + if line.strip().startswith('[1] "') and line.endswith( + '"' + ): # For strings, trim quotation marks return line[5:-1].strip() - if line.strip().startswith("[1]"): # Normal R output prefix for non-string outputs + if line.strip().startswith( + "[1]" + ): # Normal R output prefix for non-string outputs return line[4:].strip() return line diff --git a/interpreter/code_interpreters/languages/shell.py b/interpreter/code_interpreters/languages/shell.py index 136160dbd0..efaa1a94e8 100644 --- a/interpreter/code_interpreters/languages/shell.py +++ b/interpreter/code_interpreters/languages/shell.py @@ -1,6 +1,8 @@ +import os import platform + from ..subprocess_code_interpreter import SubprocessCodeInterpreter -import os + class Shell(SubprocessCodeInterpreter): file_extension = "sh" @@ -10,14 +12,14 @@ def __init__(self): super().__init__() # Determine the start command based on the platform - if platform.system() == 'Windows': - self.start_cmd = 'cmd.exe' + if platform.system() == "Windows": + self.start_cmd = "cmd.exe" else: - self.start_cmd = os.environ.get('SHELL', 'bash') + self.start_cmd = os.environ.get("SHELL", "bash") def preprocess_code(self, code): return preprocess_shell(code) - + def line_postprocessor(self, line): return line @@ -28,7 +30,7 @@ def detect_active_line(self, line): def detect_end_of_execution(self, line): return "## end_of_execution 
##" in line - + def preprocess_shell(code): """ @@ -36,13 +38,13 @@ def preprocess_shell(code): Wrap in a try except (trap in shell) Add end of execution marker """ - + # Add commands that tell us what the active line is code = add_active_line_prints(code) - + # Add end command (we'll be listening for this so we know when it ends) code += '\necho "## end_of_execution ##"' - + return code @@ -50,8 +52,8 @@ def add_active_line_prints(code): """ Add echo statements indicating line numbers to a shell string. """ - lines = code.split('\n') + lines = code.split("\n") for index, line in enumerate(lines): # Insert the echo command before the actual line lines[index] = f'echo "## active_line {index + 1} ##"\n{line}' - return '\n'.join(lines) \ No newline at end of file + return "\n".join(lines) diff --git a/interpreter/code_interpreters/subprocess_code_interpreter.py b/interpreter/code_interpreters/subprocess_code_interpreter.py index 04428cece0..9dd150a986 100644 --- a/interpreter/code_interpreters/subprocess_code_interpreter.py +++ b/interpreter/code_interpreters/subprocess_code_interpreter.py @@ -1,12 +1,12 @@ - - +import queue import subprocess import threading -import queue import time import traceback + from .base_code_interpreter import BaseCodeInterpreter + class SubprocessCodeInterpreter(BaseCodeInterpreter): def __init__(self): self.start_cmd = "" @@ -17,13 +17,13 @@ def __init__(self): def detect_active_line(self, line): return None - + def detect_end_of_execution(self, line): return None - + def line_postprocessor(self, line): return line - + def preprocess_code(self, code): """ This needs to insert an end_of_execution marker of some kind, @@ -32,7 +32,7 @@ def preprocess_code(self, code): Optionally, add active line markers for detect_active_line. 
""" return code - + def terminate(self): self.process.terminate() @@ -40,19 +40,25 @@ def start_process(self): if self.process: self.terminate() - self.process = subprocess.Popen(self.start_cmd.split(), - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - bufsize=0, - universal_newlines=True) - threading.Thread(target=self.handle_stream_output, - args=(self.process.stdout, False), - daemon=True).start() - threading.Thread(target=self.handle_stream_output, - args=(self.process.stderr, True), - daemon=True).start() + self.process = subprocess.Popen( + self.start_cmd.split(), + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=0, + universal_newlines=True, + ) + threading.Thread( + target=self.handle_stream_output, + args=(self.process.stdout, False), + daemon=True, + ).start() + threading.Thread( + target=self.handle_stream_output, + args=(self.process.stderr, True), + daemon=True, + ).start() def run(self, code): retry_count = 0 @@ -66,7 +72,6 @@ def run(self, code): except: yield {"output": traceback.format_exc()} return - while retry_count <= max_retries: if self.debug_mode: @@ -113,14 +118,14 @@ def run(self, code): break def handle_stream_output(self, stream, is_error_stream): - for line in iter(stream.readline, ''): + for line in iter(stream.readline, ""): if self.debug_mode: print(f"Received output line:\n{line}\n---") line = self.line_postprocessor(line) if line is None: - continue # `line = None` is the postprocessor's signal to discard completely + continue # `line = None` is the postprocessor's signal to discard completely if self.detect_active_line(line): active_line = self.detect_active_line(line) @@ -135,4 +140,3 @@ def handle_stream_output(self, stream, is_error_stream): self.done.set() else: self.output_queue.put({"output": line}) - diff --git a/interpreter/core/core.py b/interpreter/core/core.py index beafd1f150..b4d80624d4 100644 --- a/interpreter/core/core.py +++ 
b/interpreter/core/core.py @@ -2,22 +2,27 @@ This file defines the Interpreter class. It's the main file. `import interpreter` will import an instance of this class. """ + +import json +import os +from datetime import datetime + +import appdirs + from interpreter.utils import display_markdown_message + from ..cli.cli import cli -from ..utils.get_config import get_config, user_config_path -from ..utils.local_storage_path import get_storage_path -from .respond import respond from ..llm.setup_llm import setup_llm +from ..rag.get_relevant_procedures_string import get_relevant_procedures_string from ..terminal_interface.terminal_interface import terminal_interface from ..terminal_interface.validate_llm_settings import validate_llm_settings -from .generate_system_message import generate_system_message -import os -from datetime import datetime -from ..rag.get_relevant_procedures_string import get_relevant_procedures_string -import json from ..utils.check_for_update import check_for_update from ..utils.display_markdown_message import display_markdown_message from ..utils.embed import embed_function +from ..utils.get_config import get_config, user_config_path +from ..utils.local_storage_path import get_storage_path +from .generate_system_message import generate_system_message +from .respond import respond class Interpreter: @@ -70,11 +75,13 @@ def __init__(self): if not self.local: # This should actually be pushed into the utility if check_for_update(): - display_markdown_message("> **A new version of Open Interpreter is available.**\n>Please run: `pip install --upgrade open-interpreter`\n\n---") + display_markdown_message( + "> **A new version of Open Interpreter is available.**\n>Please run: `pip install --upgrade open-interpreter`\n\n---" + ) def extend_config(self, config_path): if self.debug_mode: - print(f'Extending configuration from `{config_path}`') + print(f"Extending configuration from `{config_path}`") config = get_config(config_path) 
self.__dict__.update(config) @@ -82,15 +89,14 @@ def extend_config(self, config_path): def chat(self, message=None, display=True, stream=False): if stream: return self._streaming_chat(message=message, display=display) - + # If stream=False, *pull* from the stream. for _ in self._streaming_chat(message=message, display=display): pass - + return self.messages - - def _streaming_chat(self, message=None, display=True): + def _streaming_chat(self, message=None, display=True): # If we have a display, # we can validate our LLM settings w/ the user first if display: @@ -107,7 +113,7 @@ def _streaming_chat(self, message=None, display=True): if display: yield from terminal_interface(self, message) return - + # One-off message if message or message == "": if message == "": @@ -117,30 +123,39 @@ def _streaming_chat(self, message=None, display=True): # Save conversation if we've turned conversation_history on if self.conversation_history: - # If it's the first message, set the conversation name if not self.conversation_filename: - - first_few_words = "_".join(self.messages[0]["message"][:25].split(" ")[:-1]) - for char in "<>:\"/\\|?*!": # Invalid characters for filenames + first_few_words = "_".join( + self.messages[0]["message"][:25].split(" ")[:-1] + ) + for char in '<>:"/\\|?*!': # Invalid characters for filenames first_few_words = first_few_words.replace(char, "") date = datetime.now().strftime("%B_%d_%Y_%H-%M-%S") - self.conversation_filename = "__".join([first_few_words, date]) + ".json" + self.conversation_filename = ( + "__".join([first_few_words, date]) + ".json" + ) # Check if the directory exists, if not, create it if not os.path.exists(self.conversation_history_path): os.makedirs(self.conversation_history_path) # Write or overwrite the file - with open(os.path.join(self.conversation_history_path, self.conversation_filename), 'w') as f: + with open( + os.path.join( + self.conversation_history_path, self.conversation_filename + ), + "w", + ) as f: 
json.dump(self.messages, f) - + return - raise Exception("`interpreter.chat()` requires a display. Set `display=True` or pass a message into `interpreter.chat(message)`.") + raise Exception( + "`interpreter.chat()` requires a display. Set `display=True` or pass a message into `interpreter.chat(message)`." + ) def _respond(self): yield from respond(self) - + def reset(self): for code_interpreter in self._code_interpreters.values(): code_interpreter.terminate() @@ -148,14 +163,16 @@ def reset(self): # Reset the two functions below, in case the user set them self.generate_system_message = lambda: generate_system_message(self) - self.get_relevant_procedures_string = lambda: get_relevant_procedures_string(self) + self.get_relevant_procedures_string = lambda: get_relevant_procedures_string( + self + ) self.__init__() - # These functions are worth exposing to developers # I wish we could just dynamically expose all of our functions to devs... def generate_system_message(self): return generate_system_message(self) + def get_relevant_procedures_string(self): return get_relevant_procedures_string(self) diff --git a/interpreter/core/generate_system_message.py b/interpreter/core/generate_system_message.py index 0430cfdab7..41776e89d5 100644 --- a/interpreter/core/generate_system_message.py +++ b/interpreter/core/generate_system_message.py @@ -1,6 +1,8 @@ -from ..utils.get_user_info_string import get_user_info_string import traceback +from ..utils.get_user_info_string import get_user_info_string + + def generate_system_message(interpreter): """ Dynamically generate a system message. 
@@ -15,7 +17,6 @@ def generate_system_message(interpreter): #### Start with the static system message system_message = interpreter.system_message - #### Add dynamic components, like the user's OS, username, etc @@ -28,4 +29,4 @@ def generate_system_message(interpreter): # In case some folks can't install the embedding model (I'm not sure if this ever happens) pass - return system_message \ No newline at end of file + return system_message diff --git a/interpreter/core/respond.py b/interpreter/core/respond.py index 5cb4d9763e..226abaf317 100644 --- a/interpreter/core/respond.py +++ b/interpreter/core/respond.py @@ -1,10 +1,13 @@ +import traceback + +import litellm + from ..code_interpreters.create_code_interpreter import create_code_interpreter -from ..utils.merge_deltas import merge_deltas +from ..code_interpreters.language_map import language_map from ..utils.display_markdown_message import display_markdown_message +from ..utils.merge_deltas import merge_deltas from ..utils.truncate_output import truncate_output -from ..code_interpreters.language_map import language_map -import traceback -import litellm + def respond(interpreter): """ @@ -13,7 +16,6 @@ def respond(interpreter): """ while True: - system_message = interpreter.generate_system_message() # Create message object @@ -28,7 +30,6 @@ def respond(interpreter): if "output" in message and message["output"] == "": message["output"] = "No output" - ### RUN THE LLM ### # Add a new message from the assistant to interpreter's "messages" attribute @@ -38,12 +39,10 @@ def respond(interpreter): # Start putting chunks into the new message # + yielding chunks to the user try: - # Track the type of chunk that the coding LLM is emitting chunk_type = None for chunk in interpreter._llm(messages_for_llm): - # Add chunk to the last message interpreter.messages[-1] = merge_deltas(interpreter.messages[-1], chunk) @@ -74,31 +73,32 @@ def respond(interpreter): yield {"end_of_message": True} elif chunk_type == "code": yield 
{"end_of_code": True} - + except litellm.exceptions.BudgetExceededError: - display_markdown_message(f"""> Max budget exceeded + display_markdown_message( + f"""> Max budget exceeded **Session spend:** ${litellm._current_cost} **Max budget:** ${interpreter.max_budget} Press CTRL-C then run `interpreter --max_budget [higher USD amount]` to proceed. - """) + """ + ) break # Provide extra information on how to change API keys, if we encounter that error # (Many people writing GitHub issues were struggling with this) except Exception as e: - if 'auth' in str(e).lower() or 'api key' in str(e).lower(): + if "auth" in str(e).lower() or "api key" in str(e).lower(): output = traceback.format_exc() - raise Exception(f"{output}\n\nThere might be an issue with your API key(s).\n\nTo reset your API key (we'll use OPENAI_API_KEY for this example, but you may need to reset your ANTHROPIC_API_KEY, HUGGINGFACE_API_KEY, etc):\n Mac/Linux: 'export OPENAI_API_KEY=your-key-here',\n Windows: 'setx OPENAI_API_KEY your-key-here' then restart terminal.\n\n") + raise Exception( + f"{output}\n\nThere might be an issue with your API key(s).\n\nTo reset your API key (we'll use OPENAI_API_KEY for this example, but you may need to reset your ANTHROPIC_API_KEY, HUGGINGFACE_API_KEY, etc):\n Mac/Linux: 'export OPENAI_API_KEY=your-key-here',\n Windows: 'setx OPENAI_API_KEY your-key-here' then restart terminal.\n\n" + ) else: raise - - - + ### RUN CODE (if it's there) ### if "code" in interpreter.messages[-1]: - if interpreter.debug_mode: print("Running code:", interpreter.messages[-1]) @@ -107,7 +107,9 @@ def respond(interpreter): code = interpreter.messages[-1]["code"] # Fix a common error where the LLM thinks it's in a Jupyter notebook - if interpreter.messages[-1]["language"] == "python" and code.startswith("!"): + if interpreter.messages[-1]["language"] == "python" and code.startswith( + "!" 
+ ): code = code[1:] interpreter.messages[-1]["code"] = code interpreter.messages[-1]["language"] = "shell" @@ -116,10 +118,12 @@ def respond(interpreter): language = interpreter.messages[-1]["language"] if language in language_map: if language not in interpreter._code_interpreters: - interpreter._code_interpreters[language] = create_code_interpreter(language) + interpreter._code_interpreters[ + language + ] = create_code_interpreter(language) code_interpreter = interpreter._code_interpreters[language] else: - #This still prints the code but don't allow code to run. Let's Open-Interpreter know through output message + # This still prints the code but don't allow code to run. Let's Open-Interpreter know through output message error_output = f"Error: Open Interpreter does not currently support {language}." print(error_output) @@ -163,4 +167,4 @@ def respond(interpreter): # Doesn't want to run code. We're done break - return \ No newline at end of file + return diff --git a/interpreter/llm/convert_to_coding_llm.py b/interpreter/llm/convert_to_coding_llm.py index acf323cb10..76f5107280 100644 --- a/interpreter/llm/convert_to_coding_llm.py +++ b/interpreter/llm/convert_to_coding_llm.py @@ -1,8 +1,7 @@ - - from ..utils.convert_to_openai_messages import convert_to_openai_messages from .setup_text_llm import setup_text_llm + def convert_to_coding_llm(text_llm, debug_mode=False): """ Takes a text_llm @@ -15,24 +14,23 @@ def coding_llm(messages): inside_code_block = False accumulated_block = "" language = None - - for chunk in text_llm(messages): + for chunk in text_llm(messages): if debug_mode: print("Chunk in coding_llm", chunk) - if ('choices' not in chunk or len(chunk['choices']) == 0): + if "choices" not in chunk or len(chunk["choices"]) == 0: # This happens sometimes continue - - content = chunk['choices'][0]['delta'].get('content', "") - + + content = chunk["choices"][0]["delta"].get("content", "") + accumulated_block += content if accumulated_block.endswith("`"): # 
We might be writing "```" one token at a time. continue - + # Did we just enter a code block? if "```" in accumulated_block and not inside_code_block: inside_code_block = True @@ -44,7 +42,6 @@ def coding_llm(messages): # If we're in a code block, if inside_code_block: - # If we don't have a `language`, find it if language is None and "\n" in accumulated_block: language = accumulated_block.split("\n")[0] @@ -53,23 +50,23 @@ def coding_llm(messages): if language == "": language = "python" else: - #Removes hallucinations containing spaces or non letters. - language = ''.join(char for char in language if char.isalpha()) + # Removes hallucinations containing spaces or non letters. + language = "".join(char for char in language if char.isalpha()) output = {"language": language} # If we recieved more than just the language in this chunk, send that if content.split("\n")[1]: output["code"] = content.split("\n")[1] - + yield output - + # If we do have a `language`, send the output as code elif language: yield {"code": content} - + # If we're not in a code block, send the output as a message if not inside_code_block: yield {"message": content} - return coding_llm \ No newline at end of file + return coding_llm diff --git a/interpreter/llm/setup_llm.py b/interpreter/llm/setup_llm.py index b2b517b15a..0ea85baaf6 100644 --- a/interpreter/llm/setup_llm.py +++ b/interpreter/llm/setup_llm.py @@ -1,10 +1,11 @@ +import os +import litellm -from .setup_text_llm import setup_text_llm from .convert_to_coding_llm import convert_to_coding_llm from .setup_openai_coding_llm import setup_openai_coding_llm -import os -import litellm +from .setup_text_llm import setup_text_llm + def setup_llm(interpreter): """ @@ -12,12 +13,14 @@ def setup_llm(interpreter): returns a Coding LLM (a generator that streams deltas with `message` and `code`). 
""" - if (not interpreter.local - and (interpreter.model in litellm.open_ai_chat_completion_models or interpreter.model.startswith("azure/"))): + if not interpreter.local and ( + interpreter.model in litellm.open_ai_chat_completion_models + or interpreter.model.startswith("azure/") + ): # Function calling LLM coding_llm = setup_openai_coding_llm(interpreter) else: text_llm = setup_text_llm(interpreter) coding_llm = convert_to_coding_llm(text_llm, debug_mode=interpreter.debug_mode) - return coding_llm \ No newline at end of file + return coding_llm diff --git a/interpreter/llm/setup_local_text_llm.py b/interpreter/llm/setup_local_text_llm.py index 2a193d4eb6..d5ed8bdd02 100644 --- a/interpreter/llm/setup_local_text_llm.py +++ b/interpreter/llm/setup_local_text_llm.py @@ -1,8 +1,11 @@ -from ..utils.display_markdown_message import display_markdown_message +import copy +import html + import inquirer import ooba -import html -import copy + +from ..utils.display_markdown_message import display_markdown_message + def setup_local_text_llm(interpreter): """ @@ -12,12 +15,16 @@ def setup_local_text_llm(interpreter): repo_id = interpreter.model.replace("huggingface/", "") - display_markdown_message(f"> **Warning**: Local LLM usage is an experimental, unstable feature.") + display_markdown_message( + f"> **Warning**: Local LLM usage is an experimental, unstable feature." + ) if repo_id != "TheBloke/Mistral-7B-Instruct-v0.1-GGUF": # ^ This means it was prob through the old --local, so we have already displayed this message. # Hacky. Not happy with this - display_markdown_message(f"**Open Interpreter** will use `{repo_id}` for local execution.") + display_markdown_message( + f"**Open Interpreter** will use `{repo_id}` for local execution." 
+ ) if "gguf" in repo_id.lower() and interpreter.gguf_quality == None: gguf_quality_choices = { @@ -25,15 +32,19 @@ def setup_local_text_llm(interpreter): "Small": 0.25, "Medium": 0.5, "Large": 0.75, - "Extra Large": 1.0 + "Extra Large": 1.0, } - questions = [inquirer.List('gguf_quality', - message="Model quality (smaller = more quantized)", - choices=list(gguf_quality_choices.keys()))] - + questions = [ + inquirer.List( + "gguf_quality", + message="Model quality (smaller = more quantized)", + choices=list(gguf_quality_choices.keys()), + ) + ] + answers = inquirer.prompt(questions) - interpreter.gguf_quality = gguf_quality_choices[answers['gguf_quality']] + interpreter.gguf_quality = gguf_quality_choices[answers["gguf_quality"]] path = ooba.download(f"https://huggingface.co/{repo_id}") @@ -70,25 +81,32 @@ def local_text_llm(messages): """ # Convert messages with function calls and outputs into "assistant" and "user" calls. - # Align Mistral lol if "mistral" in repo_id.lower(): # just.. let's try a simple system message. this seems to work fine. - messages[0]["content"] = "You are Open Interpreter. You almost always run code to complete user requests. Outside code, use markdown." - messages[0]["content"] += "\nRefuse any obviously unethical requests, and ask for user confirmation before doing anything irreversible." + messages[0][ + "content" + ] = "You are Open Interpreter. You almost always run code to complete user requests. Outside code, use markdown." + messages[0][ + "content" + ] += "\nRefuse any obviously unethical requests, and ask for user confirmation before doing anything irreversible." # Tell it how to run code. # THIS MESSAGE IS DUPLICATED IN `setup_text_llm.py` # (We should deduplicate it somehow soon. perhaps in the config?) 
- - messages = copy.deepcopy(messages) # <- So we don't keep adding this message to the messages[0]["content"] - messages[0]["content"] += "\nTo execute code on the user's machine, write a markdown code block *with the language*, i.e:\n\n```python\nprint('Hi!')\n```\nYou will recieve the output ('Hi!'). Use any language." + + messages = copy.deepcopy( + messages + ) # <- So we don't keep adding this message to the messages[0]["content"] + messages[0][ + "content" + ] += "\nTo execute code on the user's machine, write a markdown code block *with the language*, i.e:\n\n```python\nprint('Hi!')\n```\nYou will recieve the output ('Hi!'). Use any language." if interpreter.debug_mode: print("Messages going to ooba:", messages) - buffer = '' # Hold potential entity tokens and other characters. + buffer = "" # Hold potential entity tokens and other characters. for token in ooba_llm.chat(messages): # Some models like to generate HTML Entities (like ", & ') @@ -101,10 +119,12 @@ def local_text_llm(messages): buffer += token # If there's a possible incomplete entity at the end of buffer, we delay processing. - while ('&' in buffer and ';' in buffer) or (buffer.count('&') == 1 and ';' not in buffer): + while ("&" in buffer and ";" in buffer) or ( + buffer.count("&") == 1 and ";" not in buffer + ): # Find the first complete entity in the buffer. - start_idx = buffer.find('&') - end_idx = buffer.find(';', start_idx) + start_idx = buffer.find("&") + end_idx = buffer.find(";", start_idx) # If there's no complete entity, break and await more tokens. if start_idx == -1 or end_idx == -1: @@ -113,33 +133,26 @@ def local_text_llm(messages): # Yield content before the entity. for char in buffer[:start_idx]: yield make_chunk(char) - + # Extract the entity, decode it, and yield. - entity = buffer[start_idx:end_idx + 1] + entity = buffer[start_idx : end_idx + 1] yield make_chunk(html.unescape(entity)) # Remove the processed content from the buffer. 
- buffer = buffer[end_idx + 1:] + buffer = buffer[end_idx + 1 :] # If there's no '&' left in the buffer, yield all of its content. - if '&' not in buffer: + if "&" not in buffer: for char in buffer: yield make_chunk(char) - buffer = '' + buffer = "" # At the end, if there's any content left in the buffer, yield it. for char in buffer: yield make_chunk(char) - + return local_text_llm + def make_chunk(token): - return { - "choices": [ - { - "delta": { - "content": token - } - } - ] - } + return {"choices": [{"delta": {"content": token}}]} diff --git a/interpreter/llm/setup_openai_coding_llm.py b/interpreter/llm/setup_openai_coding_llm.py index fa7d254f18..afe427f0d1 100644 --- a/interpreter/llm/setup_openai_coding_llm.py +++ b/interpreter/llm/setup_openai_coding_llm.py @@ -1,33 +1,37 @@ import litellm -from ..utils.merge_deltas import merge_deltas -from ..utils.parse_partial_json import parse_partial_json -from ..utils.convert_to_openai_messages import convert_to_openai_messages -from ..utils.display_markdown_message import display_markdown_message import tokentrim as tt +from ..utils.convert_to_openai_messages import convert_to_openai_messages +from ..utils.display_markdown_message import display_markdown_message +from ..utils.merge_deltas import merge_deltas +from ..utils.parse_partial_json import parse_partial_json function_schema = { - "name": "execute", - "description": - "Executes code on the user's machine, **in the users local environment**, and returns the output", - "parameters": { - "type": "object", - "properties": { - "language": { - "type": "string", - "description": - "The programming language (required parameter to the `execute` function)", - "enum": ["python", "R", "shell", "applescript", "javascript", "html", "powershell"] - }, - "code": { - "type": "string", - "description": "The code to execute (required)" - } + "name": "execute", + "description": "Executes code on the user's machine, **in the users local environment**, and returns the output", + 
"parameters": { + "type": "object", + "properties": { + "language": { + "type": "string", + "description": "The programming language (required parameter to the `execute` function)", + "enum": [ + "python", + "R", + "shell", + "applescript", + "javascript", + "html", + "powershell", + ], + }, + "code": {"type": "string", "description": "The code to execute (required)"}, + }, + "required": ["language", "code"], }, - "required": ["language", "code"] - }, } + def setup_openai_coding_llm(interpreter): """ Takes an Interpreter (which includes a ton of LLM settings), @@ -35,12 +39,13 @@ def setup_openai_coding_llm(interpreter): """ def coding_llm(messages): - # Convert messages messages = convert_to_openai_messages(messages, function_calling=True) # Add OpenAI's recommended function message - messages[0]["content"] += "\n\nOnly use the function you have been provided with." + messages[0][ + "content" + ] += "\n\nOnly use the function you have been provided with." # Seperate out the system_message from messages # (We expect the first message to always be a system_message) @@ -49,26 +54,38 @@ def coding_llm(messages): # Trim messages, preserving the system_message try: - messages = tt.trim(messages=messages, system_message=system_message, model=interpreter.model) + messages = tt.trim( + messages=messages, + system_message=system_message, + model=interpreter.model, + ) except: if interpreter.context_window: - messages = tt.trim(messages=messages, system_message=system_message, max_tokens=interpreter.context_window) + messages = tt.trim( + messages=messages, + system_message=system_message, + max_tokens=interpreter.context_window, + ) else: - display_markdown_message(""" + display_markdown_message( + """ **We were unable to determine the context window of this model.** Defaulting to 3000. If your model can handle more, run `interpreter --context_window {token limit}` or `interpreter.context_window = {token limit}`. 
- """) - messages = tt.trim(messages=messages, system_message=system_message, max_tokens=3000) + """ + ) + messages = tt.trim( + messages=messages, system_message=system_message, max_tokens=3000 + ) if interpreter.debug_mode: print("Sending this to the OpenAI LLM:", messages) # Create LiteLLM generator params = { - 'model': interpreter.model, - 'messages': messages, - 'stream': True, - 'functions': [function_schema] + "model": interpreter.model, + "messages": messages, + "stream": True, + "functions": [function_schema], } # Optional inputs @@ -100,11 +117,10 @@ def coding_llm(messages): code = "" for chunk in response: - if interpreter.debug_mode: print("Chunk from LLM", chunk) - if ('choices' not in chunk or len(chunk['choices']) == 0): + if "choices" not in chunk or len(chunk["choices"]) == 0: # This happens sometimes continue @@ -119,24 +135,31 @@ def coding_llm(messages): if "content" in delta and delta["content"]: yield {"message": delta["content"]} - if ("function_call" in accumulated_deltas - and "arguments" in accumulated_deltas["function_call"]): - - if ("name" in accumulated_deltas["function_call"] and accumulated_deltas["function_call"]["name"] == "execute"): + if ( + "function_call" in accumulated_deltas + and "arguments" in accumulated_deltas["function_call"] + ): + if ( + "name" in accumulated_deltas["function_call"] + and accumulated_deltas["function_call"]["name"] == "execute" + ): arguments = accumulated_deltas["function_call"]["arguments"] arguments = parse_partial_json(arguments) if arguments: - if (language is None + if ( + language is None and "language" in arguments - and "code" in arguments # <- This ensures we're *finished* typing language, as opposed to partially done - and arguments["language"]): + and "code" + in arguments # <- This ensures we're *finished* typing language, as opposed to partially done + and arguments["language"] + ): language = arguments["language"] yield {"language": language} - + if language is not None and "code" in 
arguments: # Calculate the delta (new characters only) - code_delta = arguments["code"][len(code):] + code_delta = arguments["code"][len(code) :] # Update the code code = arguments["code"] # Yield the delta @@ -147,17 +170,22 @@ def coding_llm(messages): print("Arguments not a dict.") # 3.5 REALLY likes to halucinate a function named `python` and you can't really fix that, it seems. - # We just need to deal with it. - elif ("name" in accumulated_deltas["function_call"] and accumulated_deltas["function_call"]["name"] == "python"): + # We just need to deal with it. + elif ( + "name" in accumulated_deltas["function_call"] + and accumulated_deltas["function_call"]["name"] == "python" + ): if interpreter.debug_mode: print("Got direct python call") - if (language is None): + if language is None: language = "python" yield {"language": language} if language is not None: # Pull the code string straight out of the "arguments" string - code_delta = accumulated_deltas["function_call"]["arguments"][len(code):] + code_delta = accumulated_deltas["function_call"]["arguments"][ + len(code) : + ] # Update the code code = accumulated_deltas["function_call"]["arguments"] # Yield the delta @@ -166,7 +194,9 @@ def coding_llm(messages): else: if interpreter.debug_mode: - print("GOT BAD FUNCTION CALL: ", accumulated_deltas["function_call"]) + print( + "GOT BAD FUNCTION CALL: ", + accumulated_deltas["function_call"], + ) - - return coding_llm \ No newline at end of file + return coding_llm diff --git a/interpreter/llm/setup_text_llm.py b/interpreter/llm/setup_text_llm.py index f84958e141..0bdd74fead 100644 --- a/interpreter/llm/setup_text_llm.py +++ b/interpreter/llm/setup_text_llm.py @@ -1,12 +1,12 @@ - +import os +import traceback import litellm +import tokentrim as tt from ..utils.display_markdown_message import display_markdown_message from .setup_local_text_llm import setup_local_text_llm -import os -import tokentrim as tt -import traceback + def setup_text_llm(interpreter): """ @@ 
-15,7 +15,6 @@ def setup_text_llm(interpreter): """ if interpreter.local: - # Soon, we should have more options for local setup. For now we only have HuggingFace. # So we just do that. @@ -27,14 +26,14 @@ def setup_text_llm(interpreter): # this gets set up in the terminal interface / validate LLM settings. # then that's passed into this: return setup_local_text_llm(interpreter) - + # If we're here, it means the user wants to use # an OpenAI compatible endpoint running on localhost if interpreter.api_base is None: raise Exception('''To use Open Interpreter locally, either provide a huggingface model via `interpreter --model huggingface/{huggingface repo name}` or a localhost URL that exposes an OpenAI compatible endpoint by setting `interpreter --api_base {localhost URL}`.''') - + # Tell LiteLLM to treat the endpoint as an OpenAI proxy model = "custom_openai/" + interpreter.model @@ -47,13 +46,17 @@ def setup_text_llm(interpreter): traceback.print_exc() # If it didn't work, apologize and switch to GPT-4 - display_markdown_message(f""" + display_markdown_message( + f""" > Failed to install `{interpreter.model}`. \n\n**We have likely not built the proper `{interpreter.model}` support for your system.** \n\n(*Running language models locally is a difficult task!* If you have insight into the best way to implement this across platforms/architectures, please join the `Open Interpreter` community Discord, or the `Oobabooga` community Discord, and consider contributing the development of these projects.) - """) - - raise Exception("Architecture not yet supported for local LLM inference via `Oobabooga`. Please run `interpreter` to connect to a cloud model.") + """ + ) + + raise Exception( + "Architecture not yet supported for local LLM inference via `Oobabooga`. Please run `interpreter` to connect to a cloud model." 
+ ) # Pass remaining parameters to LiteLLM def base_llm(messages): @@ -71,27 +74,39 @@ def base_llm(messages): # TODO swap tt.trim for litellm util messages = messages[1:] if interpreter.context_window and interpreter.max_tokens: - trim_to_be_this_many_tokens = interpreter.context_window - interpreter.max_tokens - 25 # arbitrary buffer - messages = tt.trim(messages, system_message=system_message, max_tokens=trim_to_be_this_many_tokens) + trim_to_be_this_many_tokens = ( + interpreter.context_window - interpreter.max_tokens - 25 + ) # arbitrary buffer + messages = tt.trim( + messages, + system_message=system_message, + max_tokens=trim_to_be_this_many_tokens, + ) else: try: - messages = tt.trim(messages, system_message=system_message, model=interpreter.model) + messages = tt.trim( + messages, system_message=system_message, model=interpreter.model + ) except: - display_markdown_message(""" + display_markdown_message( + """ **We were unable to determine the context window of this model.** Defaulting to 3000. If your model can handle more, run `interpreter --context_window {token limit}` or `interpreter.context_window = {token limit}`. 
Also, please set max_tokens: `interpreter --max_tokens {max tokens per response}` or `interpreter.max_tokens = {max tokens per response}` - """) - messages = tt.trim(messages, system_message=system_message, max_tokens=3000) + """ + ) + messages = tt.trim( + messages, system_message=system_message, max_tokens=3000 + ) if interpreter.debug_mode: print("Passing messages into LLM:", messages) - + # Create LiteLLM generator params = { - 'model': interpreter.model, - 'messages': messages, - 'stream': True, + "model": interpreter.model, + "messages": messages, + "stream": True, } # Optional inputs diff --git a/interpreter/rag/get_relevant_procedures_string.py b/interpreter/rag/get_relevant_procedures_string.py index d2cc7f9538..893237564f 100644 --- a/interpreter/rag/get_relevant_procedures_string.py +++ b/interpreter/rag/get_relevant_procedures_string.py @@ -1,15 +1,20 @@ import requests + from ..utils.vector_search import search -def get_relevant_procedures_string(interpreter): +def get_relevant_procedures_string(interpreter): # Open Procedures is an open-source database of tiny, up-to-date coding tutorials. 
# We can query it semantically and append relevant tutorials/procedures to our system message # If download_open_procedures is True and interpreter.procedures is None, # We download the bank of procedures: - if interpreter.procedures is None and interpreter.download_open_procedures and not interpreter.local: + if ( + interpreter.procedures is None + and interpreter.download_open_procedures + and not interpreter.local + ): # Let's get Open Procedures from Github url = "https://raw.githubusercontent.com/KillianLucas/open-procedures/main/procedures_db.json" response = requests.get(url) @@ -40,12 +45,21 @@ def get_relevant_procedures_string(interpreter): num_results = interpreter.num_procedures - relevant_procedures = search(query_string, interpreter._procedures_db, interpreter.embed_function, num_results=num_results) + relevant_procedures = search( + query_string, + interpreter._procedures_db, + interpreter.embed_function, + num_results=num_results, + ) # This can be done better. Some procedures should just be "sticky"... - relevant_procedures_string = "[Recommended Procedures]\n" + "\n---\n".join(relevant_procedures) + "\nIn your plan, include steps and, if present, **EXACT CODE SNIPPETS** (especially for deprecation notices, **WRITE THEM INTO YOUR PLAN -- underneath each numbered step** as they will VANISH once you execute your first line of code, so WRITE THEM DOWN NOW if you need them) from the above procedures if they are relevant to the task. 
Again, include **VERBATIM CODE SNIPPETS** from the procedures above if they are relevent to the task **directly in your plan.**" + relevant_procedures_string = ( + "[Recommended Procedures]\n" + + "\n---\n".join(relevant_procedures) + + "\nIn your plan, include steps and, if present, **EXACT CODE SNIPPETS** (especially for deprecation notices, **WRITE THEM INTO YOUR PLAN -- underneath each numbered step** as they will VANISH once you execute your first line of code, so WRITE THEM DOWN NOW if you need them) from the above procedures if they are relevant to the task. Again, include **VERBATIM CODE SNIPPETS** from the procedures above if they are relevent to the task **directly in your plan.**" + ) if interpreter.debug_mode: print("Generated relevant_procedures_string:", relevant_procedures_string) - return relevant_procedures_string \ No newline at end of file + return relevant_procedures_string diff --git a/interpreter/terminal_interface/components/base_block.py b/interpreter/terminal_interface/components/base_block.py index 936e11e706..89578f5b48 100644 --- a/interpreter/terminal_interface/components/base_block.py +++ b/interpreter/terminal_interface/components/base_block.py @@ -1,12 +1,16 @@ -from rich.live import Live from rich.console import Console +from rich.live import Live + class BaseBlock: """ a visual "block" on the terminal. 
""" + def __init__(self): - self.live = Live(auto_refresh=False, console=Console(), vertical_overflow="visible") + self.live = Live( + auto_refresh=False, console=Console(), vertical_overflow="visible" + ) self.live.start() def update_from_message(self, message): @@ -17,4 +21,4 @@ def end(self): self.live.stop() def refresh(self, cursor=True): - raise NotImplementedError("Subclasses must implement this method") \ No newline at end of file + raise NotImplementedError("Subclasses must implement this method") diff --git a/interpreter/terminal_interface/components/code_block.py b/interpreter/terminal_interface/components/code_block.py index a7b18fee75..87082f5ce6 100644 --- a/interpreter/terminal_interface/components/code_block.py +++ b/interpreter/terminal_interface/components/code_block.py @@ -1,77 +1,81 @@ -from rich.panel import Panel from rich.box import MINIMAL +from rich.console import Group +from rich.panel import Panel from rich.syntax import Syntax from rich.table import Table -from rich.console import Group + from .base_block import BaseBlock -class CodeBlock(BaseBlock): - """ - Code Blocks display code and outputs in different languages. You can also set the active_line! - """ - def __init__(self): - super().__init__() +class CodeBlock(BaseBlock): + """ + Code Blocks display code and outputs in different languages. You can also set the active_line! 
+ """ - self.type = "code" + def __init__(self): + super().__init__() - # Define these for IDE auto-completion - self.language = "" - self.output = "" - self.code = "" - self.active_line = None - self.margin_top = True + self.type = "code" - def refresh(self, cursor=True): - # Get code, return if there is none - code = self.code - if not code: - return - - # Create a table for the code - code_table = Table(show_header=False, - show_footer=False, - box=None, - padding=0, - expand=True) - code_table.add_column() + # Define these for IDE auto-completion + self.language = "" + self.output = "" + self.code = "" + self.active_line = None + self.margin_top = True - # Add cursor - if cursor: - code += "●" + def refresh(self, cursor=True): + # Get code, return if there is none + code = self.code + if not code: + return - # Add each line of code to the table - code_lines = code.strip().split('\n') - for i, line in enumerate(code_lines, start=1): - if i == self.active_line: - # This is the active line, print it with a white background - syntax = Syntax(line, self.language, theme="bw", line_numbers=False, word_wrap=True) - code_table.add_row(syntax, style="black on white") - else: - # This is not the active line, print it normally - syntax = Syntax(line, self.language, theme="monokai", line_numbers=False, word_wrap=True) - code_table.add_row(syntax) + # Create a table for the code + code_table = Table( + show_header=False, show_footer=False, box=None, padding=0, expand=True + ) + code_table.add_column() - # Create a panel for the code - code_panel = Panel(code_table, box=MINIMAL, style="on #272722") + # Add cursor + if cursor: + code += "●" - # Create a panel for the output (if there is any) - if self.output == "" or self.output == "None": - output_panel = "" - else: - output_panel = Panel(self.output, - box=MINIMAL, - style="#FFFFFF on #3b3b37") + # Add each line of code to the table + code_lines = code.strip().split("\n") + for i, line in enumerate(code_lines, start=1): + if 
i == self.active_line: + # This is the active line, print it with a white background + syntax = Syntax( + line, self.language, theme="bw", line_numbers=False, word_wrap=True + ) + code_table.add_row(syntax, style="black on white") + else: + # This is not the active line, print it normally + syntax = Syntax( + line, + self.language, + theme="monokai", + line_numbers=False, + word_wrap=True, + ) + code_table.add_row(syntax) - # Create a group with the code table and output panel - group_items = [code_panel, output_panel] - if self.margin_top: - # This adds some space at the top. Just looks good! - group_items = [""] + group_items - group = Group(*group_items) + # Create a panel for the code + code_panel = Panel(code_table, box=MINIMAL, style="on #272722") - # Update the live display - self.live.update(group) - self.live.refresh() + # Create a panel for the output (if there is any) + if self.output == "" or self.output == "None": + output_panel = "" + else: + output_panel = Panel(self.output, box=MINIMAL, style="#FFFFFF on #3b3b37") + # Create a group with the code table and output panel + group_items = [code_panel, output_panel] + if self.margin_top: + # This adds some space at the top. Just looks good! 
+ group_items = [""] + group_items + group = Group(*group_items) + # Update the live display + self.live.update(group) + self.live.refresh() diff --git a/interpreter/terminal_interface/components/message_block.py b/interpreter/terminal_interface/components/message_block.py index 87ebce2458..5da980855d 100644 --- a/interpreter/terminal_interface/components/message_block.py +++ b/interpreter/terminal_interface/components/message_block.py @@ -1,48 +1,50 @@ -from rich.panel import Panel -from rich.markdown import Markdown -from rich.box import MINIMAL import re + +from rich.box import MINIMAL +from rich.markdown import Markdown +from rich.panel import Panel + from .base_block import BaseBlock + class MessageBlock(BaseBlock): + def __init__(self): + super().__init__() + + self.type = "message" + self.message = "" + self.has_run = False - def __init__(self): - super().__init__() + def refresh(self, cursor=True): + # De-stylize any code blocks in markdown, + # to differentiate from our Code Blocks + content = textify_markdown_code_blocks(self.message) - self.type = "message" - self.message = "" - self.has_run = False + if cursor: + content += "●" - def refresh(self, cursor=True): - # De-stylize any code blocks in markdown, - # to differentiate from our Code Blocks - content = textify_markdown_code_blocks(self.message) - - if cursor: - content += "●" - - markdown = Markdown(content.strip()) - panel = Panel(markdown, box=MINIMAL) - self.live.update(panel) - self.live.refresh() + markdown = Markdown(content.strip()) + panel = Panel(markdown, box=MINIMAL) + self.live.update(panel) + self.live.refresh() def textify_markdown_code_blocks(text): - """ - To distinguish CodeBlocks from markdown code, we simply turn all markdown code - (like '```python...') into text code blocks ('```text') which makes the code black and white. 
- """ - replacement = "```text" - lines = text.split('\n') - inside_code_block = False - - for i in range(len(lines)): - # If the line matches ``` followed by optional language specifier - if re.match(r'^```(\w*)$', lines[i].strip()): - inside_code_block = not inside_code_block - - # If we just entered a code block, replace the marker - if inside_code_block: - lines[i] = replacement - - return '\n'.join(lines) + """ + To distinguish CodeBlocks from markdown code, we simply turn all markdown code + (like '```python...') into text code blocks ('```text') which makes the code black and white. + """ + replacement = "```text" + lines = text.split("\n") + inside_code_block = False + + for i in range(len(lines)): + # If the line matches ``` followed by optional language specifier + if re.match(r"^```(\w*)$", lines[i].strip()): + inside_code_block = not inside_code_block + + # If we just entered a code block, replace the marker + if inside_code_block: + lines[i] = replacement + + return "\n".join(lines) diff --git a/interpreter/terminal_interface/conversation_navigator.py b/interpreter/terminal_interface/conversation_navigator.py index 6611426983..dae33c6f7a 100644 --- a/interpreter/terminal_interface/conversation_navigator.py +++ b/interpreter/terminal_interface/conversation_navigator.py @@ -2,23 +2,27 @@ This file handles conversations. """ -import inquirer -import subprocess -import platform -import os import json -from .render_past_conversation import render_past_conversation +import os +import platform +import subprocess + +import inquirer + from ..utils.display_markdown_message import display_markdown_message from ..utils.local_storage_path import get_storage_path +from .render_past_conversation import render_past_conversation -def conversation_navigator(interpreter): +def conversation_navigator(interpreter): conversations_dir = get_storage_path("conversations") - display_markdown_message(f"""> Conversations are stored in "`{conversations_dir}`". 
+ display_markdown_message( + f"""> Conversations are stored in "`{conversations_dir}`". Select a conversation to resume. - """) + """ + ) # Check if conversations directory exists if not os.path.exists(conversations_dir): @@ -26,12 +30,18 @@ def conversation_navigator(interpreter): return None # Get list of all JSON files in the directory - json_files = [f for f in os.listdir(conversations_dir) if f.endswith('.json')] + json_files = [f for f in os.listdir(conversations_dir) if f.endswith(".json")] # Make a dict that maps reformatted "First few words... (September 23rd)" -> "First_few_words__September_23rd.json" (original file name) readable_names_and_filenames = {} for filename in json_files: - name = filename.replace(".json", "").replace(".JSON", "").replace("__", "... (").replace("_", " ") + ")" + name = ( + filename.replace(".json", "") + .replace(".JSON", "") + .replace("__", "... (") + .replace("_", " ") + + ")" + ) readable_names_and_filenames[name] = filename # Add the option to open the folder. 
This doesn't map to a filename, we'll catch it @@ -39,22 +49,23 @@ def conversation_navigator(interpreter): # Use inquirer to let the user select a file questions = [ - inquirer.List('name', - message="", - choices=readable_names_and_filenames.keys(), - ), + inquirer.List( + "name", + message="", + choices=readable_names_and_filenames.keys(), + ), ] answers = inquirer.prompt(questions) # If the user selected to open the folder, do so and return - if answers['name'] == "> Open folder": + if answers["name"] == "> Open folder": open_folder(conversations_dir) return - selected_filename = readable_names_and_filenames[answers['name']] + selected_filename = readable_names_and_filenames[answers["name"]] # Open the selected file and load the JSON data - with open(os.path.join(conversations_dir, selected_filename), 'r') as f: + with open(os.path.join(conversations_dir, selected_filename), "r") as f: messages = json.load(f) # Pass the data into render_past_conversation @@ -67,6 +78,7 @@ def conversation_navigator(interpreter): # Start the chat interpreter.chat() + def open_folder(path): if platform.system() == "Windows": os.startfile(path) @@ -74,4 +86,4 @@ def open_folder(path): subprocess.run(["open", path]) else: # Assuming it's Linux - subprocess.run(["xdg-open", path]) \ No newline at end of file + subprocess.run(["xdg-open", path]) diff --git a/interpreter/terminal_interface/magic_commands.py b/interpreter/terminal_interface/magic_commands.py index 7c838babe2..9db0a6b0ac 100644 --- a/interpreter/terminal_interface/magic_commands.py +++ b/interpreter/terminal_interface/magic_commands.py @@ -1,19 +1,21 @@ -from ..utils.display_markdown_message import display_markdown_message -from ..utils.count_tokens import count_messages_tokens import json import os +from ..utils.count_tokens import count_messages_tokens +from ..utils.display_markdown_message import display_markdown_message + + def handle_undo(self, arguments): # Removes all messages after the most recent user entry 
(and the entry itself). # Therefore user can jump back to the latest point of conversation. # Also gives a visual representation of the messages removed. if len(self.messages) == 0: - return + return # Find the index of the last 'role': 'user' entry last_user_index = None for i, message in enumerate(self.messages): - if message.get('role') == 'user': + if message.get("role") == "user": last_user_index = i removed_messages = [] @@ -23,38 +25,41 @@ def handle_undo(self, arguments): removed_messages = self.messages[last_user_index:] self.messages = self.messages[:last_user_index] - print("") # Aesthetics. + print("") # Aesthetics. # Print out a preview of what messages were removed. for message in removed_messages: - if 'content' in message and message['content'] != None: - display_markdown_message(f"**Removed message:** `\"{message['content'][:30]}...\"`") - elif 'function_call' in message: - display_markdown_message(f"**Removed codeblock**") # TODO: Could add preview of code removed here. - - print("") # Aesthetics. + if "content" in message and message["content"] != None: + display_markdown_message( + f"**Removed message:** `\"{message['content'][:30]}...\"`" + ) + elif "function_call" in message: + display_markdown_message( + f"**Removed codeblock**" + ) # TODO: Could add preview of code removed here. + + print("") # Aesthetics. + def handle_help(self, arguments): commands_description = { - "%debug [true/false]": "Toggle debug mode. Without arguments or with 'true', it enters debug mode. With 'false', it exits debug mode.", - "%reset": "Resets the current session.", - "%undo": "Remove previous messages and its response from the message history.", - "%save_message [path]": "Saves messages to a specified JSON path. If no path is provided, it defaults to 'messages.json'.", - "%load_message [path]": "Loads messages from a specified JSON path. 
If no path is provided, it defaults to 'messages.json'.", - "%tokens [prompt]": "EXPERIMENTAL: Calculate the tokens used by the next request based on the current conversation's messages and estimate the cost of that request; optionally provide a prompt to also calulate the tokens used by that prompt and the total amount of tokens that will be sent with the next request", - "%help": "Show this help message.", + "%debug [true/false]": "Toggle debug mode. Without arguments or with 'true', it enters debug mode. With 'false', it exits debug mode.", + "%reset": "Resets the current session.", + "%undo": "Remove previous messages and its response from the message history.", + "%save_message [path]": "Saves messages to a specified JSON path. If no path is provided, it defaults to 'messages.json'.", + "%load_message [path]": "Loads messages from a specified JSON path. If no path is provided, it defaults to 'messages.json'.", + "%tokens [prompt]": "EXPERIMENTAL: Calculate the tokens used by the next request based on the current conversation's messages and estimate the cost of that request; optionally provide a prompt to also calculate the tokens used by that prompt and the total amount of tokens that will be sent with the next request", + "%help": "Show this help message.", } - base_message = [ - "> **Available Commands:**\n\n" - ] + base_message = ["> **Available Commands:**\n\n"] # Add each command and its description to the message for cmd, desc in commands_description.items(): - base_message.append(f"- `{cmd}`: {desc}\n") + base_message.append(f"- `{cmd}`: {desc}\n") additional_info = [ - "\n\nFor further assistance, please join our community Discord or consider contributing to the project's development." + "\n\nFor further assistance, please join our community Discord or consider contributing to the project's development."
] # Combine the base message with the additional info @@ -74,33 +79,40 @@ def handle_debug(self, arguments=None): else: display_markdown_message("> Unknown argument to debug command.") + def handle_reset(self, arguments): self.reset() display_markdown_message("> Reset Done") + def default_handle(self, arguments): display_markdown_message("> Unknown command") - handle_help(self,arguments) + handle_help(self, arguments) + def handle_save_message(self, json_path): if json_path == "": - json_path = "messages.json" + json_path = "messages.json" if not json_path.endswith(".json"): - json_path += ".json" - with open(json_path, 'w') as f: - json.dump(self.messages, f, indent=2) + json_path += ".json" + with open(json_path, "w") as f: + json.dump(self.messages, f, indent=2) display_markdown_message(f"> messages json export to {os.path.abspath(json_path)}") + def handle_load_message(self, json_path): if json_path == "": - json_path = "messages.json" + json_path = "messages.json" if not json_path.endswith(".json"): - json_path += ".json" - with open(json_path, 'r') as f: - self.messages = json.load(f) + json_path += ".json" + with open(json_path, "r") as f: + self.messages = json.load(f) + + display_markdown_message( + f"> messages json loaded from {os.path.abspath(json_path)}" + ) - display_markdown_message(f"> messages json loaded from {os.path.abspath(json_path)}") def handle_count_tokens(self, prompt): messages = [{"role": "system", "message": self.system_message}] + self.messages @@ -108,22 +120,38 @@ def handle_count_tokens(self, prompt): outputs = [] if len(self.messages) == 0: - (conversation_tokens, conversation_cost) = count_messages_tokens(messages=messages, model=self.model) + (conversation_tokens, conversation_cost) = count_messages_tokens( + messages=messages, model=self.model + ) else: - (conversation_tokens, conversation_cost) = count_messages_tokens(messages=messages, model=self.model) + (conversation_tokens, conversation_cost) = count_messages_tokens( + 
messages=messages, model=self.model + ) - outputs.append((f"> Tokens sent with next request as context: {conversation_tokens} (Estimated Cost: ${conversation_cost})")) + outputs.append( + ( + f"> Tokens sent with next request as context: {conversation_tokens} (Estimated Cost: ${conversation_cost})" + ) + ) if prompt: - (prompt_tokens, prompt_cost) = count_messages_tokens(messages=[prompt], model=self.model) - outputs.append(f"> Tokens used by this prompt: {prompt_tokens} (Estimated Cost: ${prompt_cost})") + (prompt_tokens, prompt_cost) = count_messages_tokens( + messages=[prompt], model=self.model + ) + outputs.append( + f"> Tokens used by this prompt: {prompt_tokens} (Estimated Cost: ${prompt_cost})" + ) - total_tokens = conversation_tokens + prompt_tokens - total_cost = conversation_cost + prompt_cost + total_tokens = conversation_tokens + prompt_tokens + total_cost = conversation_cost + prompt_cost - outputs.append(f"> Total tokens for next request with this prompt: {total_tokens} (Estimated Cost: ${total_cost})") + outputs.append( + f"> Total tokens for next request with this prompt: {total_tokens} (Estimated Cost: ${total_cost})" + ) - outputs.append(f"**Note**: This functionality is currently experimental and may not be accurate. Please report any issues you find to the [Open Interpreter GitHub repository](https://github.com/KillianLucas/open-interpreter).") + outputs.append( + f"**Note**: This functionality is currently experimental and may not be accurate. Please report any issues you find to the [Open Interpreter GitHub repository](https://github.com/KillianLucas/open-interpreter)." 
+ ) display_markdown_message("\n".join(outputs)) @@ -131,17 +159,19 @@ def handle_count_tokens(self, prompt): def handle_magic_command(self, user_input): # split the command into the command and the arguments, by the first whitespace switch = { - "help": handle_help, - "debug": handle_debug, - "reset": handle_reset, - "save_message": handle_save_message, - "load_message": handle_load_message, - "undo": handle_undo, - "tokens": handle_count_tokens, + "help": handle_help, + "debug": handle_debug, + "reset": handle_reset, + "save_message": handle_save_message, + "load_message": handle_load_message, + "undo": handle_undo, + "tokens": handle_count_tokens, } user_input = user_input[1:].strip() # Capture the part after the `%` command = user_input.split(" ")[0] - arguments = user_input[len(command):].strip() - action = switch.get(command, default_handle) # Get the function from the dictionary, or default_handle if not found - action(self, arguments) # Execute the function + arguments = user_input[len(command) :].strip() + action = switch.get( + command, default_handle + ) # Get the function from the dictionary, or default_handle if not found + action(self, arguments) # Execute the function diff --git a/interpreter/terminal_interface/render_past_conversation.py b/interpreter/terminal_interface/render_past_conversation.py index a9a0fa4dec..8e98e90b25 100644 --- a/interpreter/terminal_interface/render_past_conversation.py +++ b/interpreter/terminal_interface/render_past_conversation.py @@ -1,7 +1,8 @@ +from ..utils.display_markdown_message import display_markdown_message from .components.code_block import CodeBlock from .components.message_block import MessageBlock from .magic_commands import handle_magic_command -from ..utils.display_markdown_message import display_markdown_message + def render_past_conversation(messages): # This is a clone of the terminal interface. 
@@ -40,7 +41,7 @@ def render_past_conversation(messages): active_block = CodeBlock() ran_code_block = False render_cursor = True - + if "language" in chunk: active_block.language = chunk["language"] if "code" in chunk: @@ -53,7 +54,7 @@ def render_past_conversation(messages): ran_code_block = True render_cursor = False active_block.output += "\n" + chunk["output"] - active_block.output = active_block.output.strip() # <- Aesthetic choice + active_block.output = active_block.output.strip() # <- Aesthetic choice if active_block: active_block.refresh(cursor=render_cursor) @@ -61,4 +62,4 @@ def render_past_conversation(messages): # (Sometimes -- like if they CTRL-C quickly -- active_block is still None here) if active_block: active_block.end() - active_block = None \ No newline at end of file + active_block = None diff --git a/interpreter/terminal_interface/terminal_interface.py b/interpreter/terminal_interface/terminal_interface.py index c3a2c54ec2..4d820bb93d 100644 --- a/interpreter/terminal_interface/terminal_interface.py +++ b/interpreter/terminal_interface/terminal_interface.py @@ -8,13 +8,13 @@ except ImportError: pass +from ..utils.check_for_package import check_for_package +from ..utils.display_markdown_message import display_markdown_message +from ..utils.scan_code import scan_code +from ..utils.truncate_output import truncate_output from .components.code_block import CodeBlock from .components.message_block import MessageBlock from .magic_commands import handle_magic_command -from ..utils.display_markdown_message import display_markdown_message -from ..utils.truncate_output import truncate_output -from ..utils.scan_code import scan_code -from ..utils.check_for_package import check_for_package def terminal_interface(interpreter, message): @@ -25,16 +25,16 @@ def terminal_interface(interpreter, message): if interpreter.safe_mode == "ask" or interpreter.safe_mode == "auto": if not check_for_package("semgrep"): - interpreter_intro_message.append(f"**Safe Mode**: 
{interpreter.safe_mode}\n\n>Note: **Safe Mode** requires `semgrep` (`pip install semgrep`)") + interpreter_intro_message.append( + f"**Safe Mode**: {interpreter.safe_mode}\n\n>Note: **Safe Mode** requires `semgrep` (`pip install semgrep`)" + ) else: - interpreter_intro_message.append( - "Use `interpreter -y` to bypass this." - ) + interpreter_intro_message.append("Use `interpreter -y` to bypass this.") interpreter_intro_message.append("Press `CTRL-C` to exit.") display_markdown_message("\n\n".join(interpreter_intro_message) + "\n") - + active_block = None if message: @@ -64,12 +64,12 @@ def terminal_interface(interpreter, message): # In the event we get code -> output -> code again ran_code_block = False render_cursor = True - + try: for chunk in interpreter.chat(message, display=False, stream=True): if interpreter.debug_mode: print("Chunk in `terminal_interface`:", chunk) - + # Message if "message" in chunk: if active_block is None: @@ -91,7 +91,7 @@ def terminal_interface(interpreter, message): active_block = CodeBlock() ran_code_block = False render_cursor = True - + if "language" in chunk: active_block.language = chunk["language"] if "code" in chunk: @@ -112,8 +112,10 @@ def terminal_interface(interpreter, message): if not interpreter.safe_mode == "off": if interpreter.safe_mode == "auto": should_scan_code = True - elif interpreter.safe_mode == 'ask': - response = input(" Would you like to scan this code? (y/n)\n\n ") + elif interpreter.safe_mode == "ask": + response = input( + " Would you like to scan this code? (y/n)\n\n " + ) print("") # <- Aesthetic choice if response.strip().lower() == "y": @@ -127,22 +129,26 @@ def terminal_interface(interpreter, message): scan_code(code, language, interpreter) - response = input(" Would you like to run this code? (y/n)\n\n ") + response = input( + " Would you like to run this code? 
(y/n)\n\n " + ) print("") # <- Aesthetic choice if response.strip().lower() == "y": # Create a new, identical block where the code will actually be run # Conveniently, the chunk includes everything we need to do this: active_block = CodeBlock() - active_block.margin_top = False # <- Aesthetic choice + active_block.margin_top = False # <- Aesthetic choice active_block.language = chunk["executing"]["language"] active_block.code = chunk["executing"]["code"] else: # User declined to run code. - interpreter.messages.append({ - "role": "user", - "message": "I have declined to run this code." - }) + interpreter.messages.append( + { + "role": "user", + "message": "I have declined to run this code.", + } + ) break # Output @@ -150,10 +156,14 @@ def terminal_interface(interpreter, message): ran_code_block = True render_cursor = False active_block.output += "\n" + chunk["output"] - active_block.output = active_block.output.strip() # <- Aesthetic choice - + active_block.output = ( + active_block.output.strip() + ) # <- Aesthetic choice + # Truncate output - active_block.output = truncate_output(active_block.output, interpreter.max_output) + active_block.output = truncate_output( + active_block.output, interpreter.max_output + ) if active_block: active_block.refresh(cursor=render_cursor) @@ -174,9 +184,9 @@ def terminal_interface(interpreter, message): if active_block: active_block.end() active_block = None - + if interactive: # (this cancels LLM, returns to the interactive "> " input) continue else: - break \ No newline at end of file + break diff --git a/interpreter/terminal_interface/validate_llm_settings.py b/interpreter/terminal_interface/validate_llm_settings.py index 9a22654a2b..79923f0867 100644 --- a/interpreter/terminal_interface/validate_llm_settings.py +++ b/interpreter/terminal_interface/validate_llm_settings.py @@ -1,9 +1,12 @@ +import getpass import os -from ..utils.display_markdown_message import display_markdown_message import time + import inquirer import 
litellm -import getpass + +from ..utils.display_markdown_message import display_markdown_message + def validate_llm_settings(interpreter): """ @@ -13,16 +16,16 @@ def validate_llm_settings(interpreter): # This runs in a while loop so `continue` lets us start from the top # after changing settings (like switching to/from local) while True: - if interpreter.local: # Ensure model is downloaded and ready to be set up if interpreter.model == "": - # Interactive prompt to download the best local model we know of - display_markdown_message(""" - **Open Interpreter** will use `Mistral 7B` for local execution.""") + display_markdown_message( + """ + **Open Interpreter** will use `Mistral 7B` for local execution.""" + ) if interpreter.gguf_quality == None: interpreter.gguf_quality = 0.35 @@ -43,26 +46,25 @@ def validate_llm_settings(interpreter): """ interpreter.model = "huggingface/TheBloke/Mistral-7B-Instruct-v0.1-GGUF" - + break else: - # They have selected a model. Have they downloaded it? # Break here because currently, this is handled in llm/setup_local_text_llm.py # How should we unify all this? break - + else: # Ensure API keys are set as environment variables # OpenAI if interpreter.model in litellm.open_ai_chat_completion_models: if not os.environ.get("OPENAI_API_KEY") and not interpreter.api_key: - display_welcome_message_once() - display_markdown_message("""--- + display_markdown_message( + """--- > OpenAI API key not found To use `GPT-4` (recommended) please provide an OpenAI API key. @@ -70,28 +72,33 @@ def validate_llm_settings(interpreter): To use `Mistral-7B` (free but less capable) press `enter`. --- - """) + """ + ) response = getpass.getpass("OpenAI API key: ") print(f"OpenAI API key: {response[:4]}...{response[-4:]}") if response == "": # User pressed `enter`, requesting Mistral-7B - display_markdown_message("""> Switching to `Mistral-7B`... + display_markdown_message( + """> Switching to `Mistral-7B`... 
**Tip:** Run `interpreter --local` to automatically use `Mistral-7B`. - ---""") + ---""" + ) time.sleep(1.5) interpreter.local = True interpreter.model = "" continue - - display_markdown_message(""" + + display_markdown_message( + """ **Tip:** To save this key for later, run `export OPENAI_API_KEY=your_api_key` on Mac/Linux or `setx OPENAI_API_KEY your_api_key` on Windows. - ---""") + ---""" + ) interpreter.api_key = response time.sleep(2) @@ -112,16 +119,17 @@ def validate_llm_settings(interpreter): def display_welcome_message_once(): """ Displays a welcome message only on its first call. - + (Uses an internal attribute `_displayed` to track its state.) """ if not hasattr(display_welcome_message_once, "_displayed"): - - display_markdown_message(""" + display_markdown_message( + """ ● Welcome to **Open Interpreter**. - """) + """ + ) time.sleep(1.5) - display_welcome_message_once._displayed = True \ No newline at end of file + display_welcome_message_once._displayed = True diff --git a/interpreter/utils/check_for_package.py b/interpreter/utils/check_for_package.py index 159bb37346..d7bde1bf5b 100644 --- a/interpreter/utils/check_for_package.py +++ b/interpreter/utils/check_for_package.py @@ -1,19 +1,20 @@ import importlib.util import sys -#borrowed from: https://stackoverflow.com/a/1051266/656011 + +# borrowed from: https://stackoverflow.com/a/1051266/656011 def check_for_package(package): - if package in sys.modules: - return True - elif (spec := importlib.util.find_spec(package)) is not None: - try: - module = importlib.util.module_from_spec(spec) + if package in sys.modules: + return True + elif (spec := importlib.util.find_spec(package)) is not None: + try: + module = importlib.util.module_from_spec(spec) - sys.modules[package] = module - spec.loader.exec_module(module) + sys.modules[package] = module + spec.loader.exec_module(module) - return True - except ImportError: - return False - else: - return False \ No newline at end of file + return True + except 
ImportError: + return False + else: + return False diff --git a/interpreter/utils/check_for_update.py b/interpreter/utils/check_for_update.py index 57b2c33135..d8fbb79e43 100644 --- a/interpreter/utils/check_for_update.py +++ b/interpreter/utils/check_for_update.py @@ -1,13 +1,14 @@ -import requests import pkg_resources +import requests from packaging import version + def check_for_update(): # Fetch the latest version from the PyPI API - response = requests.get(f'https://pypi.org/pypi/open-interpreter/json') - latest_version = response.json()['info']['version'] + response = requests.get(f"https://pypi.org/pypi/open-interpreter/json") + latest_version = response.json()["info"]["version"] # Get the current version using pkg_resources current_version = pkg_resources.get_distribution("open-interpreter").version - return version.parse(latest_version) > version.parse(current_version) \ No newline at end of file + return version.parse(latest_version) > version.parse(current_version) diff --git a/interpreter/utils/convert_to_openai_messages.py b/interpreter/utils/convert_to_openai_messages.py index 5e0d2ff873..2fcedc2f3e 100644 --- a/interpreter/utils/convert_to_openai_messages.py +++ b/interpreter/utils/convert_to_openai_messages.py @@ -1,13 +1,11 @@ import json + def convert_to_openai_messages(messages, function_calling=True): new_messages = [] - for message in messages: - new_message = { - "role": message["role"], - "content": "" - } + for message in messages: + new_message = {"role": message["role"], "content": ""} if "message" in message: new_message["content"] = message["message"] @@ -16,34 +14,40 @@ def convert_to_openai_messages(messages, function_calling=True): if function_calling: new_message["function_call"] = { "name": "execute", - "arguments": json.dumps({ - "language": message["language"], - "code": message["code"] - }), + "arguments": json.dumps( + {"language": message["language"], "code": message["code"]} + ), # parsed_arguments isn't actually an OpenAI 
thing, it's an OI thing. # but it's soo useful! we use it to render messages to text_llms "parsed_arguments": { "language": message["language"], - "code": message["code"] - } + "code": message["code"], + }, } else: - new_message["content"] += f"""\n\n```{message["language"]}\n{message["code"]}\n```""" + new_message[ + "content" + ] += f"""\n\n```{message["language"]}\n{message["code"]}\n```""" new_message["content"] = new_message["content"].strip() new_messages.append(new_message) if "output" in message: if function_calling: - new_messages.append({ - "role": "function", - "name": "execute", - "content": message["output"] - }) + new_messages.append( + { + "role": "function", + "name": "execute", + "content": message["output"], + } + ) else: - new_messages.append({ - "role": "user", - "content": "CODE EXECUTED ON USERS MACHINE. OUTPUT (invisible to the user): " + message["output"] - }) + new_messages.append( + { + "role": "user", + "content": "CODE EXECUTED ON USERS MACHINE. OUTPUT (invisible to the user): " + + message["output"], + } + ) - return new_messages \ No newline at end of file + return new_messages diff --git a/interpreter/utils/count_tokens.py b/interpreter/utils/count_tokens.py index bda66a325b..19e38d41e8 100644 --- a/interpreter/utils/count_tokens.py +++ b/interpreter/utils/count_tokens.py @@ -1,6 +1,7 @@ import tiktoken from litellm import cost_per_token + def count_tokens(text="", model="gpt-4"): """ Count the number of tokens in a string @@ -10,6 +11,7 @@ def count_tokens(text="", model="gpt-4"): return len(encoder.encode(text)) + def token_cost(tokens=0, model="gpt-4"): """ Calculate the cost of the current number of tokens @@ -19,6 +21,7 @@ def token_cost(tokens=0, model="gpt-4"): return round(prompt_cost, 6) + def count_messages_tokens(messages=[], model=None): """ Count the number of tokens in a list of messages @@ -41,4 +44,3 @@ def count_messages_tokens(messages=[], model=None): prompt_cost = token_cost(tokens_used, model=model) return 
(tokens_used, prompt_cost) - diff --git a/interpreter/utils/display_markdown_message.py b/interpreter/utils/display_markdown_message.py index 9d06d2bbd5..08f4bf370e 100644 --- a/interpreter/utils/display_markdown_message.py +++ b/interpreter/utils/display_markdown_message.py @@ -2,6 +2,7 @@ from rich.markdown import Markdown from rich.rule import Rule + def display_markdown_message(message): """ Display markdown message. Works with multiline strings with lots of indentation. @@ -19,4 +20,4 @@ def display_markdown_message(message): if "\n" not in message and message.startswith(">"): # Aesthetic choice. For these tags, they need a space below them - print("") \ No newline at end of file + print("") diff --git a/interpreter/utils/embed.py b/interpreter/utils/embed.py index eb8f4f9d2a..d98e6b03a9 100644 --- a/interpreter/utils/embed.py +++ b/interpreter/utils/embed.py @@ -1,9 +1,12 @@ -from chromadb.utils.embedding_functions import DefaultEmbeddingFunction as setup_embed import os + import numpy as np +from chromadb.utils.embedding_functions import DefaultEmbeddingFunction as setup_embed # Set up the embedding function -os.environ["TOKENIZERS_PARALLELISM"] = "false" # Otherwise setup_embed displays a warning message +os.environ[ + "TOKENIZERS_PARALLELISM" +] = "false" # Otherwise setup_embed displays a warning message try: chroma_embedding_function = setup_embed() except: @@ -11,5 +14,6 @@ # If it fails, it's not worth breaking everything. 
pass + def embed_function(query): - return np.squeeze(chroma_embedding_function([query])).tolist() \ No newline at end of file + return np.squeeze(chroma_embedding_function([query])).tolist() diff --git a/interpreter/utils/get_config.py b/interpreter/utils/get_config.py index 558726b27b..ceb897e70d 100644 --- a/interpreter/utils/get_config.py +++ b/interpreter/utils/get_config.py @@ -1,7 +1,8 @@ import os -import yaml -from importlib import resources import shutil +from importlib import resources + +import yaml from .local_storage_path import get_storage_path @@ -9,6 +10,7 @@ user_config_path = os.path.join(get_storage_path(), config_filename) + def get_config_path(path=user_config_path): # check to see if we were given a path that exists if not os.path.exists(path): @@ -28,23 +30,23 @@ def get_config_path(path=user_config_path): else: # Ensure the user-specific directory exists os.makedirs(get_storage_path(), exist_ok=True) - + # otherwise, we'll create the file in our default config directory path = os.path.join(get_storage_path(), path) - # If user's config doesn't exist, copy the default config from the package here = os.path.abspath(os.path.dirname(__file__)) parent_dir = os.path.dirname(here) - default_config_path = os.path.join(parent_dir, 'config.yaml') + default_config_path = os.path.join(parent_dir, "config.yaml") # Copying the file using shutil.copy new_file = shutil.copy(default_config_path, path) return path + def get_config(path=user_config_path): path = get_config_path(path) - with open(path, 'r') as file: - return yaml.safe_load(file) \ No newline at end of file + with open(path, "r") as file: + return yaml.safe_load(file) diff --git a/interpreter/utils/get_conversations.py b/interpreter/utils/get_conversations.py index 43375a065b..036380a446 100644 --- a/interpreter/utils/get_conversations.py +++ b/interpreter/utils/get_conversations.py @@ -2,7 +2,8 @@ from ..utils.local_storage_path import get_storage_path + def get_conversations(): 
conversations_dir = get_storage_path("conversations") - json_files = [f for f in os.listdir(conversations_dir) if f.endswith('.json')] - return json_files \ No newline at end of file + json_files = [f for f in os.listdir(conversations_dir) if f.endswith(".json")] + return json_files diff --git a/interpreter/utils/get_local_models_paths.py b/interpreter/utils/get_local_models_paths.py index f7764d8020..a8a23874d5 100644 --- a/interpreter/utils/get_local_models_paths.py +++ b/interpreter/utils/get_local_models_paths.py @@ -2,7 +2,8 @@ from ..utils.local_storage_path import get_storage_path + def get_local_models_paths(): models_dir = get_storage_path("models") files = [os.path.join(models_dir, f) for f in os.listdir(models_dir)] - return files \ No newline at end of file + return files diff --git a/interpreter/utils/get_user_info_string.py b/interpreter/utils/get_user_info_string.py index d6d2fda712..d2239de118 100644 --- a/interpreter/utils/get_user_info_string.py +++ b/interpreter/utils/get_user_info_string.py @@ -2,11 +2,11 @@ import os import platform -def get_user_info_string(): +def get_user_info_string(): username = getpass.getuser() current_working_directory = os.getcwd() operating_system = platform.system() - default_shell = os.environ.get('SHELL') + default_shell = os.environ.get("SHELL") - return f"[User Info]\nName: {username}\nCWD: {current_working_directory}\nSHELL: {default_shell}\nOS: {operating_system}" \ No newline at end of file + return f"[User Info]\nName: {username}\nCWD: {current_working_directory}\nSHELL: {default_shell}\nOS: {operating_system}" diff --git a/interpreter/utils/local_storage_path.py b/interpreter/utils/local_storage_path.py index a4540b1116..502c9d169c 100644 --- a/interpreter/utils/local_storage_path.py +++ b/interpreter/utils/local_storage_path.py @@ -1,9 +1,11 @@ import os + import appdirs # Using appdirs to determine user-specific config path config_dir = appdirs.user_config_dir("Open Interpreter") + def 
get_storage_path(subdirectory=None): if subdirectory is None: return config_dir diff --git a/interpreter/utils/merge_deltas.py b/interpreter/utils/merge_deltas.py index 8ea7853658..71be04d424 100644 --- a/interpreter/utils/merge_deltas.py +++ b/interpreter/utils/merge_deltas.py @@ -1,6 +1,7 @@ import json import re + def merge_deltas(original, delta): """ Pushes the delta into the original and returns that. @@ -18,4 +19,4 @@ def merge_deltas(original, delta): original[key] += value else: original[key] = value - return original \ No newline at end of file + return original diff --git a/interpreter/utils/parse_partial_json.py b/interpreter/utils/parse_partial_json.py index d80e3cba37..f66117f63b 100644 --- a/interpreter/utils/parse_partial_json.py +++ b/interpreter/utils/parse_partial_json.py @@ -1,14 +1,14 @@ import json import re -def parse_partial_json(s): +def parse_partial_json(s): # Attempt to parse the string as-is. try: return json.loads(s) except json.JSONDecodeError: pass - + # Initialize variables. new_s = "" stack = [] @@ -20,9 +20,9 @@ def parse_partial_json(s): if is_inside_string: if char == '"' and not escaped: is_inside_string = False - elif char == '\n' and not escaped: - char = '\\n' # Replace the newline character with the escape sequence. - elif char == '\\': + elif char == "\n" and not escaped: + char = "\\n" # Replace the newline character with the escape sequence. + elif char == "\\": escaped = not escaped else: escaped = False @@ -30,17 +30,17 @@ def parse_partial_json(s): if char == '"': is_inside_string = True escaped = False - elif char == '{': - stack.append('}') - elif char == '[': - stack.append(']') - elif char == '}' or char == ']': + elif char == "{": + stack.append("}") + elif char == "[": + stack.append("]") + elif char == "}" or char == "]": if stack and stack[-1] == char: stack.pop() else: # Mismatched closing character; the input is malformed. return None - + # Append the processed character to the new string. 
new_s += char diff --git a/interpreter/utils/scan_code.py b/interpreter/utils/scan_code.py index fa5db98431..89dccd5649 100644 --- a/interpreter/utils/scan_code.py +++ b/interpreter/utils/scan_code.py @@ -1,10 +1,11 @@ import os import subprocess + from yaspin import yaspin from yaspin.spinners import Spinners -from .temporary_file import create_temporary_file, cleanup_temporary_file from ..code_interpreters.language_map import language_map +from .temporary_file import cleanup_temporary_file, create_temporary_file def get_language_file_extension(language_name): diff --git a/interpreter/utils/truncate_output.py b/interpreter/utils/truncate_output.py index 1ca324372c..08ed843b68 100644 --- a/interpreter/utils/truncate_output.py +++ b/interpreter/utils/truncate_output.py @@ -1,15 +1,15 @@ def truncate_output(data, max_output_chars=2000): - needs_truncation = False + needs_truncation = False - message = f'Output truncated. Showing the last {max_output_chars} characters.\n\n' + message = f"Output truncated. 
Showing the last {max_output_chars} characters.\n\n" - # Remove previous truncation message if it exists - if data.startswith(message): - data = data[len(message):] - needs_truncation = True + # Remove previous truncation message if it exists + if data.startswith(message): + data = data[len(message) :] + needs_truncation = True - # If data exceeds max length, truncate it and add message - if len(data) > max_output_chars or needs_truncation: - data = message + data[-max_output_chars:] + # If data exceeds max length, truncate it and add message + if len(data) > max_output_chars or needs_truncation: + data = message + data[-max_output_chars:] - return data \ No newline at end of file + return data diff --git a/interpreter/utils/vector_search.py b/interpreter/utils/vector_search.py index d610231ea4..22efccd2f3 100644 --- a/interpreter/utils/vector_search.py +++ b/interpreter/utils/vector_search.py @@ -1,5 +1,6 @@ -from chromadb.utils.distance_functions import cosine import numpy as np +from chromadb.utils.distance_functions import cosine + def search(query, db, embed_function, num_results=2): """ @@ -19,10 +20,12 @@ def search(query, db, embed_function, num_results=2): query_embedding = embed_function(query) # Calculate the cosine distance between the query embedding and each embedding in the database - distances = {value: cosine(query_embedding, embedding) for value, embedding in db.items()} + distances = { + value: cosine(query_embedding, embedding) for value, embedding in db.items() + } # Sort the values by their distance to the query, and select the top num_results most_similar_values = sorted(distances, key=distances.get)[:num_results] # Return the most similar values - return most_similar_values \ No newline at end of file + return most_similar_values diff --git a/pyproject.toml b/pyproject.toml index d69fe0be6c..d11e3ab50a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,9 @@ wget = "^3.2" yaspin = "^3.0.1" [tool.poetry.group.dev.dependencies] 
+black = "^23.10.1" +isort = "^5.12.0" +pre-commit = "^3.5.0" pytest = "^7.4.0" [build-system] @@ -40,4 +43,12 @@ build-backend = "poetry.core.masonry.api" interpreter = "interpreter:cli" [tool.poetry.extras] -safe = ["semgrep"] \ No newline at end of file +safe = ["semgrep"] + +[tool.black] +target-version = ['py311'] + +[tool.isort] +profile = "black" +multi_line_output = 3 +include_trailing_comma = true diff --git a/tests/test_interpreter.py b/tests/test_interpreter.py index 0a13a6deab..40693074db 100644 --- a/tests/test_interpreter.py +++ b/tests/test_interpreter.py @@ -1,10 +1,11 @@ import os -from random import randint import time +from random import randint + import pytest + import interpreter -from interpreter.utils.count_tokens import count_tokens, count_messages_tokens -import time +from interpreter.utils.count_tokens import count_messages_tokens, count_tokens interpreter.auto_run = True interpreter.model = "gpt-4" @@ -32,7 +33,7 @@ def test_config_loading(): # path manipulation to get the actual path to the config file or our config # loader will try to load from the wrong directory and fail currentPath = os.path.dirname(os.path.abspath(__file__)) - config_path=os.path.join(currentPath, './config.test.yaml') + config_path = os.path.join(currentPath, "./config.test.yaml") interpreter.extend_config(config_path=config_path) @@ -43,6 +44,7 @@ def test_config_loading(): assert temperature_ok and model_ok and debug_mode_ok + def test_system_message_appending(): ping_system_message = ( "Respond to a `ping` with a `pong`. No code. No explanations. Just `pong`." @@ -67,21 +69,29 @@ def test_reset(): def test_token_counter(): - system_tokens = count_tokens(text=interpreter.system_message, model=interpreter.model) - + system_tokens = count_tokens( + text=interpreter.system_message, model=interpreter.model + ) + prompt = "How many tokens is this?" 
prompt_tokens = count_tokens(text=prompt, model=interpreter.model) - messages = [{"role": "system", "message": interpreter.system_message}] + interpreter.messages + messages = [ + {"role": "system", "message": interpreter.system_message} + ] + interpreter.messages - system_token_test = count_messages_tokens(messages=messages, model=interpreter.model) + system_token_test = count_messages_tokens( + messages=messages, model=interpreter.model + ) system_tokens_ok = system_tokens == system_token_test[0] messages.append({"role": "user", "message": prompt}) - prompt_token_test = count_messages_tokens(messages=messages, model=interpreter.model) + prompt_token_test = count_messages_tokens( + messages=messages, model=interpreter.model + ) prompt_tokens_ok = system_tokens + prompt_tokens == prompt_token_test[0] @@ -100,6 +110,7 @@ def test_hello_world(): {"role": "assistant", "message": hello_world_response}, ] + @pytest.mark.skip(reason="Math is hard") def test_math(): # we'll generate random integers between this min and max in our math tests @@ -127,7 +138,10 @@ def test_delayed_exec(): """Can you write a single block of code and execute it that prints something, then delays 1 second, then prints something else? No talk just code. Thanks!""" ) -@pytest.mark.skip(reason="This works fine when I run it but fails frequently in Github Actions... will look into it after the hackathon") + +@pytest.mark.skip( + reason="This works fine when I run it but fails frequently in Github Actions... will look into it after the hackathon" +) def test_nested_loops_and_multiple_newlines(): interpreter.chat( """Can you write a nested for loop in python and shell and run them? Don't forget to properly format your shell script and use semicolons where necessary. Also put 1-3 newlines between each line in the code. Only generate and execute the code. No explanations. Thanks!"""