Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
96e213e
new prompts
rohan-gopalam Mar 28, 2025
a73eec4
WE DID IT
rohan-gopalam Apr 1, 2025
16ded40
kernel prompts script
annbypark Apr 1, 2025
5794270
working prompt generator script
annbypark Apr 1, 2025
eb1f965
prompt generation script with txt files
annbypark Apr 1, 2025
927949a
all prompts generated
annbypark Apr 1, 2025
753b096
kinda
rohan-gopalam Apr 1, 2025
be61967
Merge branch 'single_pass' of github.com:arjun-banerjee/torch2nki int…
rohan-gopalam Apr 1, 2025
64fde00
cleaner term version
rohan-gopalam Apr 2, 2025
34b161b
deals w output diffs
rohan-gopalam Apr 2, 2025
6c9fb3b
Midsem updates
rohan-gopalam Apr 2, 2025
c3b89c2
compiling on baremetal
rohan-gopalam Apr 5, 2025
3688132
testing all element wise for tiling
rahulvijay04 Apr 5, 2025
a69d1c3
single pass pipeline with tiling
rohan-gopalam Apr 14, 2025
60b5bc3
Merge single_pass into single_pass_commit
Apr 15, 2025
7d37524
Merged single_pass into single_pass_commit using ours
Apr 15, 2025
f074906
deleted files to clean up for pr
rohan-gopalam Apr 15, 2025
7b86474
Merge branch 'single_pass_commit' of github.com:arjun-banerjee/torch2…
rohan-gopalam Apr 15, 2025
1975d2d
clean up for pr
rohan-gopalam Apr 15, 2025
af61c8d
add prompts folder
Apr 17, 2025
b976b16
kernel prompt takes pytorch documentation instead of calling openai
annbypark Apr 17, 2025
6fc4f0c
current
rohan-gopalam May 14, 2025
79a61d7
Merge branch 'single_pass_commit' of github.com:arjun-banerjee/torch2…
rohan-gopalam May 14, 2025
b3e9b17
multielement kernels
rohan-gopalam May 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
921 changes: 921 additions & 0 deletions generation/langchain_single_pass/all_in_one_generator.py

Large diffs are not rendered by default.

888 changes: 888 additions & 0 deletions generation/langchain_single_pass/all_in_one_generator_new.py

Large diffs are not rendered by default.

128 changes: 128 additions & 0 deletions generation/langchain_single_pass/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,30 @@
import datetime
import re

def update_function_name_in_text(text, new_name):
"""
Updates the function name in the function header of a text string.

The function expects the function header to follow this format:
def old_function_name(arguments):
<body lines>

Args:
text (str): The text content to update
new_name (str): New function name to replace the old one with

Returns:
str: The updated text content with the new function name
"""
# Updated regex to capture standard Python function definitions
pattern = r'^(def\s+)([^\s(]+)(\s*\(.*\):)' # Matches 'def function_name(args):'
# Replace with new function name while preserving 'def' and arguments
replacement = r'\1' + new_name + r'\3'
# Replace the first occurrence of the function definition
new_text = re.sub(pattern, replacement, text, count=1, flags=re.MULTILINE)

return new_text


def extract_kernel_from_llm_response(content):
"""
Expand Down Expand Up @@ -60,3 +84,107 @@ def log_to_file(log_file_path, message, append=True):
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file_path, mode, encoding="utf-8") as f:
f.write(f"[{timestamp}] {message}\n")

class ExecutionServer:
"""A server capable of running test functions with specified device and NKI function."""

def __init__(self, device='cpu'):
"""Initialize the execution server.

Args:
device: The device to run tests on (default: 'cpu')
"""
self.device = device
import tests
self.tests = tests

@staticmethod
def load_kernel_module(kernel_path):
"""Dynamically load the kernel module from the given path."""
import importlib.util
import os
import sys

# Remove .py extension if present
if kernel_path.endswith('.py'):
kernel_path = kernel_path[:-3]

# Get module name from path
module_name = os.path.basename(kernel_path)

# Import the module
spec = importlib.util.spec_from_file_location(module_name, kernel_path + '.py')
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module

def run(self, test_func_name, kernel_func_name, kernel_module_path, output_file):
"""Run a test function with the specified NKI function and save output.

Args:
test_func_name: The name of the test function from tests.py to run
kernel_module_path: Path to the kernel module to test
output_file: Path to save the output to

Returns:
The combined stdout and stderr output from running the test
"""
import sys
from io import StringIO

# Load the kernel module
try:
kernel_module = self.load_kernel_module(kernel_module_path)
except Exception as e:
error = f"Error loading kernel module: {str(e)}"
with open(output_file, "w", encoding="utf-8") as f:
f.write(error)
return error

# Capture stdout and stderr
stdout = StringIO()
stderr = StringIO()
old_stdout, old_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = stdout, stderr

try:
test_func = getattr(self.tests, test_func_name)
# Get the kernel function - it should have the same name as the operator
kernel_func = getattr(kernel_module, kernel_func_name)
test_func(self.device, kernel_func)
except Exception as e:
print(f"Error running test: {str(e)}")
import traceback
traceback.print_exc()
finally:
# Restore stdout and stderr
sys.stdout, sys.stderr = old_stdout, old_stderr

# Get the output
output = stdout.getvalue() + "\n" + stderr.getvalue()
stdout.close()
stderr.close()

# Save to file
with open(output_file, "w", encoding="utf-8") as f:
f.write(output)

print(f"Test output saved to {output_file}")
return output

def run(test_func_name, kernel_func_name, kernel_module_path, output_file, device='cpu'):
"""Run a test function using an execution server and save output.

Args:
test_func_name: The name of the test function from tests.py to run (e.g., 'test_torch_addition')
kernel_func_name: The name of the kernel function to test (e.g., 'nki_vector_add')
kernel_module_path: Path to the kernel module to test
output_file: Path to save the output to
device: The device to run on (default: 'cpu')

Returns:
The combined stdout and stderr output from running the test
"""
server = ExecutionServer(device)
return server.run(test_func_name, kernel_func_name, kernel_module_path, output_file)
Loading