Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 47 additions & 74 deletions api/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from adalflow.core.db import LocalDB
from api.config import configs
from api.ollama_patch import OllamaDocumentProcessor
from urllib.parse import urlparse, urlunparse, quote

# Configure logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,7 +45,7 @@ def count_tokens(text: str, local_ollama: bool = False) -> int:
# Rough approximation: 4 characters per token
return len(text) // 4

def download_repo(repo_url: str, local_path: str, access_token: str = None):
def download_repo(repo_url: str, local_path: str, type: str = "github", access_token: str = None) -> str:
"""
Downloads a Git repository (GitHub, GitLab, or Bitbucket) to a specified local path.

Expand Down Expand Up @@ -78,16 +79,17 @@ def download_repo(repo_url: str, local_path: str, access_token: str = None):
# Prepare the clone URL with access token if provided
clone_url = repo_url
if access_token:
parsed = urlparse(repo_url)
# Determine the repository type and format the URL accordingly
if "github.com" in repo_url:
if type == "github":
# Format: https://{token}@github.com/owner/repo.git
clone_url = repo_url.replace("https://", f"https://{access_token}@")
elif "gitlab.com" in repo_url:
# Format: https://oauth2:{token}@gitlab.com/owner/repo.git
clone_url = repo_url.replace("https://", f"https://oauth2:{access_token}@")
elif "bitbucket.org" in repo_url:
clone_url = urlunparse((parsed.scheme, f"{access_token}@{parsed.netloc}", parsed.path, '', '', ''))
elif type == "gitlab":
# Format: https://oauth2:{token}@gitlab.com/owner/repo.git
clone_url = urlunparse((parsed.scheme, f"oauth2:{access_token}@{parsed.netloc}", parsed.path, '', '', ''))
elif type == "bitbucket":
# Format: https://{token}@bitbucket.org/owner/repo.git
clone_url = repo_url.replace("https://", f"https://{access_token}@")
clone_url = urlunparse((parsed.scheme, f"{access_token}@{parsed.netloc}", parsed.path, '', '', ''))
logger.info("Using access token for authentication")

# Clone the repository
Expand Down Expand Up @@ -370,46 +372,40 @@ def get_github_file_content(repo_url: str, file_path: str, access_token: str = N

def get_gitlab_file_content(repo_url: str, file_path: str, access_token: str = None) -> str:
"""
Retrieves the content of a file from a GitLab repository using the GitLab API.
Retrieves the content of a file from a GitLab repository (cloud or self-hosted).

Args:
repo_url (str): The URL of the GitLab repository (e.g., "https://gitlab.com/username/repo")
file_path (str): The path to the file within the repository (e.g., "src/main.py")
access_token (str, optional): GitLab personal access token for private repositories
repo_url (str): The GitLab repo URL (e.g., "https://gitlab.com/username/repo" or "http://localhost/group/project")
file_path (str): File path within the repository (e.g., "src/main.py")
access_token (str, optional): GitLab personal access token

Returns:
str: The content of the file as a string
str: File content

Raises:
ValueError: If the file cannot be fetched or if the URL is not a valid GitLab URL
ValueError: If anything fails
"""
try:
# Extract owner and repo name from GitLab URL
if not (repo_url.startswith("https://gitlab.com/") or repo_url.startswith("http://gitlab.com/")):
# Parse and validate the URL
parsed_url = urlparse(repo_url)
if not parsed_url.scheme or not parsed_url.netloc:
raise ValueError("Not a valid GitLab repository URL")

parts = repo_url.rstrip('/').split('/')
if len(parts) < 5:
raise ValueError("Invalid GitLab URL format")

# For GitLab, the URL format can be:
# - https://gitlab.com/username/repo
# - https://gitlab.com/group/subgroup/repo
# We need to extract the project path with namespace
gitlab_domain = f"{parsed_url.scheme}://{parsed_url.netloc}"
if parsed_url.port not in (None, 80, 443):
gitlab_domain += f":{parsed_url.port}"
path_parts = parsed_url.path.strip("/").split("/")
if len(path_parts) < 2:
raise ValueError("Invalid GitLab URL format — expected something like https://gitlab.domain.com/group/project")

# Remove the domain part
path_parts = parts[3:]
# Join the remaining parts to get the project path with namespace
project_path = '/'.join(path_parts).replace(".git", "")
# URL encode the path for API use
encoded_project_path = project_path.replace('/', '%2F')
# Build project path and encode for API
project_path = "/".join(path_parts).replace(".git", "")
encoded_project_path = quote(project_path, safe='')

# Use GitLab API to get file content
# The API endpoint for getting file content is: /api/v4/projects/{encoded_project_path}/repository/files/{encoded_file_path}/raw
encoded_file_path = file_path.replace('/', '%2F')
api_url = f"https://gitlab.com/api/v4/projects/{encoded_project_path}/repository/files/{encoded_file_path}/raw?ref=main"
# Encode file path
encoded_file_path = quote(file_path, safe='')

# Prepare curl command with authentication if token is provided
api_url = f"{gitlab_domain}/api/v4/projects/{encoded_project_path}/repository/files/{encoded_file_path}/raw?ref={default_branch}"
curl_cmd = ["curl", "-s"]
if access_token:
curl_cmd.extend(["-H", f"PRIVATE-TOKEN: {access_token}"])
Expand All @@ -423,37 +419,14 @@ def get_gitlab_file_content(repo_url: str, file_path: str, access_token: str = N
stderr=subprocess.PIPE,
)

# GitLab API returns the raw file content directly
content = result.stdout.decode("utf-8")

# Check if we got an error response (GitLab returns JSON for errors)
if content.startswith('{') and '"message":' in content:
# Check for GitLab error response (JSON instead of raw file)
if content.startswith("{") and '"message":' in content:
try:
error_data = json.loads(content)
if "message" in error_data:
# Try with 'master' branch if 'main' failed
api_url = f"https://gitlab.com/api/v4/projects/{encoded_project_path}/repository/files/{encoded_file_path}/raw?ref=master"
logger.info(f"Retrying with master branch: {api_url}")

# Prepare curl command for retry
curl_cmd = ["curl", "-s"]
if access_token:
curl_cmd.extend(["-H", f"PRIVATE-TOKEN: {access_token}"])
curl_cmd.append(api_url)

result = subprocess.run(
curl_cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
content = result.stdout.decode("utf-8")

# Check again for error
if content.startswith('{') and '"message":' in content:
error_data = json.loads(content)
if "message" in error_data:
raise ValueError(f"GitLab API error: {error_data['message']}")
raise ValueError(f"GitLab API error: {error_data['message']}")
except json.JSONDecodeError:
# If it's not valid JSON, it's probably the file content
pass
Expand Down Expand Up @@ -532,7 +505,7 @@ def get_bitbucket_file_content(repo_url: str, file_path: str, access_token: str
raise ValueError(f"Failed to get file content: {str(e)}")


def get_file_content(repo_url: str, file_path: str, access_token: str = None) -> str:
def get_file_content(repo_url: str, file_path: str, type: str = "github", access_token: str = None) -> str:
"""
Retrieves the content of a file from a Git repository (GitHub or GitLab).

Expand All @@ -547,11 +520,11 @@ def get_file_content(repo_url: str, file_path: str, access_token: str = None) ->
Raises:
ValueError: If the file cannot be fetched or if the URL is not valid
"""
if "github.com" in repo_url:
if type == "github":
return get_github_file_content(repo_url, file_path, access_token)
elif "gitlab.com" in repo_url:
elif type == "gitlab":
return get_gitlab_file_content(repo_url, file_path, access_token)
elif "bitbucket.org" in repo_url:
elif type == "bitbucket":
return get_bitbucket_file_content(repo_url, file_path, access_token)
else:
raise ValueError("Unsupported repository URL. Only GitHub and GitLab are supported.")
Expand All @@ -566,7 +539,7 @@ def __init__(self):
self.repo_url_or_path = None
self.repo_paths = None

def prepare_database(self, repo_url_or_path: str, access_token: str = None, local_ollama: bool = False,
def prepare_database(self, repo_url_or_path: str, type: str = "github", access_token: str = None, local_ollama: bool = False,
excluded_dirs: List[str] = None, excluded_files: List[str] = None) -> List[Document]:
"""
Create a new database from the repository.
Expand All @@ -582,7 +555,7 @@ def prepare_database(self, repo_url_or_path: str, access_token: str = None, loca
List[Document]: List of Document objects
"""
self.reset_database()
self._create_repo(repo_url_or_path, access_token)
self._create_repo(repo_url_or_path, type, access_token)
return self.prepare_db_index(local_ollama=local_ollama, excluded_dirs=excluded_dirs, excluded_files=excluded_files)

def reset_database(self):
Expand All @@ -593,7 +566,7 @@ def reset_database(self):
self.repo_url_or_path = None
self.repo_paths = None

def _create_repo(self, repo_url_or_path: str, access_token: str = None) -> None:
def _create_repo(self, repo_url_or_path: str, type: str = "github", access_token: str = None) -> None:
"""
Download and prepare all paths.
Paths:
Expand All @@ -613,14 +586,14 @@ def _create_repo(self, repo_url_or_path: str, access_token: str = None) -> None:
# url
if repo_url_or_path.startswith("https://") or repo_url_or_path.startswith("http://"):
# Extract repo name based on the URL format
if "github.com" in repo_url_or_path:
if type == "github":
# GitHub URL format: https://github.com/owner/repo
repo_name = repo_url_or_path.split("/")[-1].replace(".git", "")
elif "gitlab.com" in repo_url_or_path:
elif type == "gitlab":
# GitLab URL format: https://gitlab.com/owner/repo or https://gitlab.com/group/subgroup/repo
# Use the last part of the URL as the repo name
repo_name = repo_url_or_path.split("/")[-1].replace(".git", "")
elif "bitbucket.org" in repo_url_or_path:
elif type == "bitbucket":
# Bitbucket URL format: https://bitbucket.org/owner/repo
repo_name = repo_url_or_path.split("/")[-1].replace(".git", "")
else:
Expand All @@ -632,7 +605,7 @@ def _create_repo(self, repo_url_or_path: str, access_token: str = None) -> None:
# Check if the repository directory already exists and is not empty
if not (os.path.exists(save_repo_dir) and os.listdir(save_repo_dir)):
# Only download if the repository doesn't exist or is empty
download_repo(repo_url_or_path, save_repo_dir, access_token)
download_repo(repo_url_or_path, save_repo_dir, type, access_token)
else:
logger.info(f"Repository already exists at {save_repo_dir}. Using existing repository.")
else: # local path
Expand Down Expand Up @@ -695,7 +668,7 @@ def prepare_db_index(self, local_ollama: bool = False, excluded_dirs: List[str]
logger.info(f"Total transformed documents: {len(transformed_docs)}")
return transformed_docs

def prepare_retriever(self, repo_url_or_path: str, access_token: str = None):
def prepare_retriever(self, repo_url_or_path: str, type: str = "github", access_token: str = None):
"""
Prepare the retriever for a repository.
This is a compatibility method for the isolated API.
Expand All @@ -707,4 +680,4 @@ def prepare_retriever(self, repo_url_or_path: str, access_token: str = None):
Returns:
List[Document]: List of Document objects
"""
return self.prepare_database(repo_url_or_path, access_token)
return self.prepare_database(repo_url_or_path, type, access_token)
3 changes: 2 additions & 1 deletion api/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def initialize_db_manager(self):
self.db_manager = DatabaseManager()
self.transformed_docs = []

def prepare_retriever(self, repo_url_or_path: str, access_token: str = None, local_ollama: bool = False,
def prepare_retriever(self, repo_url_or_path: str, type: str = "github", access_token: str = None, local_ollama: bool = False,
excluded_dirs: List[str] = None, excluded_files: List[str] = None):
"""
Prepare the retriever for a repository.
Expand All @@ -304,6 +304,7 @@ def prepare_retriever(self, repo_url_or_path: str, access_token: str = None, loc
self.repo_url_or_path = repo_url_or_path
self.transformed_docs = self.db_manager.prepare_database(
repo_url_or_path,
type,
access_token,
local_ollama=local_ollama,
excluded_dirs=excluded_dirs,
Expand Down
36 changes: 5 additions & 31 deletions api/simple_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ class ChatCompletionRequest(BaseModel):
repo_url: str = Field(..., description="URL of the repository to query")
messages: List[ChatMessage] = Field(..., description="List of chat messages")
filePath: Optional[str] = Field(None, description="Optional path to a file in the repository to include in the prompt")
github_token: Optional[str] = Field(None, description="GitHub personal access token for private repositories")
gitlab_token: Optional[str] = Field(None, description="GitLab personal access token for private repositories")
bitbucket_token: Optional[str] = Field(None, description="Bitbucket personal access token for private repositories")
token: Optional[str] = Field(None, description="Personal access token for private repositories")
type: Optional[str] = Field("github", description="Type of repository (e.g., 'github', 'gitlab', 'bitbucket')")

# model parameters
provider: str = Field("google", description="Model provider (google, openai, openrouter, ollama)")
Expand Down Expand Up @@ -91,18 +90,6 @@ async def chat_completions_stream(request: ChatCompletionRequest):
try:
request_rag = RAG(provider=request.provider, model=request.model)

# Determine which access token to use based on the repository URL
access_token = None
if "github.com" in request.repo_url and request.github_token:
access_token = request.github_token
logger.info("Using GitHub token for authentication")
elif "gitlab.com" in request.repo_url and request.gitlab_token:
access_token = request.gitlab_token
logger.info("Using GitLab token for authentication")
elif "bitbucket.org" in request.repo_url and request.bitbucket_token:
access_token = request.bitbucket_token
logger.info("Using Bitbucket token for authentication")

# Extract custom file filter parameters if provided
excluded_dirs = None
excluded_files = None
Expand All @@ -113,7 +100,7 @@ async def chat_completions_stream(request: ChatCompletionRequest):
excluded_files = [unquote(file_pattern) for file_pattern in request.excluded_files.split('\n') if file_pattern.strip()]
logger.info(f"Using custom excluded files: {excluded_files}")

request_rag.prepare_retriever(request.repo_url, access_token, False, excluded_dirs, excluded_files)
request_rag.prepare_retriever(request.repo_url, request.type, request.token, False, excluded_dirs, excluded_files)
logger.info(f"Retriever prepared for {request.repo_url}")
except Exception as e:
logger.error(f"Error preparing retriever: {str(e)}")
Expand Down Expand Up @@ -233,11 +220,7 @@ async def chat_completions_stream(request: ChatCompletionRequest):
repo_name = repo_url.split("/")[-1] if "/" in repo_url else repo_url

# Determine repository type
repo_type = "GitHub"
if "gitlab.com" in repo_url:
repo_type = "GitLab"
elif "bitbucket.org" in repo_url:
repo_type = "Bitbucket"
repo_type = request.type

# Get language information
language_code = request.language or "en"
Expand Down Expand Up @@ -396,16 +379,7 @@ async def chat_completions_stream(request: ChatCompletionRequest):
file_content = ""
if request.filePath:
try:
# Determine which access token to use
access_token = None
if "github.com" in request.repo_url and request.github_token:
access_token = request.github_token
elif "gitlab.com" in request.repo_url and request.gitlab_token:
access_token = request.gitlab_token
elif "bitbucket.org" in request.repo_url and request.bitbucket_token:
access_token = request.bitbucket_token

file_content = get_file_content(request.repo_url, request.filePath, access_token)
file_content = get_file_content(request.repo_url, request.filePath, request.type, request.token)
logger.info(f"Successfully retrieved content for file: {request.filePath}")
except Exception as e:
logger.error(f"Error retrieving file content: {str(e)}")
Expand Down
Loading