Skip to content

feat: bring your own llm #138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/postgres-new/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public/sw.mjs
51 changes: 5 additions & 46 deletions apps/postgres-new/app/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import { createOpenAI } from '@ai-sdk/openai'
import { Ratelimit } from '@upstash/ratelimit'
import { kv } from '@vercel/kv'
import { convertToCoreMessages, streamText, ToolInvocation, ToolResultPart } from 'ai'
import { codeBlock } from 'common-tags'
import { convertToCoreTools, maxMessageContext, maxRowLimit, tools } from '~/lib/tools'
import { getSystemPrompt } from '~/lib/system-prompt'
import { convertToCoreTools, maxMessageContext, tools } from '~/lib/tools'
import { createClient } from '~/utils/supabase/server'
import { ChatInferenceEventToolResult, logEvent } from '~/utils/telemetry'

Expand Down Expand Up @@ -72,49 +72,8 @@ export async function POST(req: Request) {
const coreMessages = convertToCoreMessages(trimmedMessageContext)
const coreTools = convertToCoreTools(tools)

const result = await streamText({
system: codeBlock`
You are a helpful database assistant. Under the hood you have access to an in-browser Postgres database called PGlite (https://github.com/electric-sql/pglite).
Some special notes about this database:
- foreign data wrappers are not supported
- the following extensions are available:
- plpgsql [pre-enabled]
- vector (https://github.com/pgvector/pgvector) [pre-enabled]
- use <=> for cosine distance (default to this)
- use <#> for negative inner product
- use <-> for L2 distance
- use <+> for L1 distance
- note queried vectors will be truncated/redacted due to their size - export as CSV if the full vector is required

When generating tables, do the following:
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
- Prefer 'text' over 'varchar'
- Keep explanations brief but helpful
- Don't repeat yourself after creating the table

When creating sample data:
- Make the data realistic, including joined data
- Check for existing records/conflicts in the table

When querying data, limit to 5 by default. The maximum number of rows you're allowed to fetch is ${maxRowLimit} (to protect AI from token abuse).
If the user needs to fetch more than ${maxRowLimit} rows at once, they can export the query as a CSV.

When performing FTS, always use 'simple' (languages aren't available).

When importing CSVs try to solve the problem yourself (eg. use a generic text column, then refine)
vs. asking the user to change the CSV. No need to select rows after importing.

You also know math. All math equations and expressions must be written in KaTex and must be wrapped in double dollar \`$$\`:
- Inline: $$\\sqrt{26}$$
- Multiline:
$$
\\sqrt{26}
$$

No images are allowed. Do not try to generate or link images, including base64 data URLs.

Feel free to suggest corrections for suspected typos.
`,
const result = streamText({
system: getSystemPrompt(),
model: openai(chatModel),
messages: coreMessages,
tools: coreTools,
Expand Down Expand Up @@ -158,7 +117,7 @@ export async function POST(req: Request) {
},
})

return result.toAIStreamResponse()
return result.toDataStreamResponse()
}

function getEventToolResult(toolResult: ToolResultPart): ChatInferenceEventToolResult | undefined {
Expand Down
16 changes: 16 additions & 0 deletions apps/postgres-new/components/app-provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
import { legacyDomainHostname } from '~/lib/util'
import { parse, serialize } from '~/lib/websocket-protocol'
import { createClient } from '~/utils/supabase/client'
import { useModelProvider } from './model-provider/use-model-provider'

export type AppProps = PropsWithChildren

Expand Down Expand Up @@ -252,6 +253,9 @@ export default function AppProvider({ children }: AppProps) {
const [isLegacyDomain, setIsLegacyDomain] = useState(false)
const [isLegacyDomainRedirect, setIsLegacyDomainRedirect] = useState(false)

const [modelProviderError, setModelProviderError] = useState<string>()
const [isModelProviderDialogOpen, setIsModelProviderDialogOpen] = useState(false)

useEffect(() => {
const isLegacyDomain = window.location.hostname === legacyDomainHostname
const urlParams = new URLSearchParams(window.location.search)
Expand All @@ -263,12 +267,17 @@ export default function AppProvider({ children }: AppProps) {
setIsRenameDialogOpen(isLegacyDomain || isLegacyDomainRedirect)
}, [])

const modelProvider = useModelProvider()

return (
<AppContext.Provider
value={{
user,
isLoadingUser,
liveShare,
modelProvider,
modelProviderError,
setModelProviderError,
signIn,
signOut,
isSignInDialogOpen,
Expand All @@ -277,6 +286,8 @@ export default function AppProvider({ children }: AppProps) {
setIsRenameDialogOpen,
isRateLimited,
setIsRateLimited,
isModelProviderDialogOpen,
setIsModelProviderDialogOpen,
focusRef,
dbManager,
pgliteVersion,
Expand Down Expand Up @@ -305,6 +316,8 @@ export type AppContextValues = {
setIsRenameDialogOpen: (open: boolean) => void
isRateLimited: boolean
setIsRateLimited: (limited: boolean) => void
isModelProviderDialogOpen: boolean
setIsModelProviderDialogOpen: (open: boolean) => void
focusRef: RefObject<FocusHandle>
dbManager?: DbManager
pgliteVersion?: string
Expand All @@ -316,6 +329,9 @@ export type AppContextValues = {
clientIp: string | null
isLiveSharing: boolean
}
modelProvider: ReturnType<typeof useModelProvider>
modelProviderError?: string
setModelProviderError: (error: string | undefined) => void
isLegacyDomain: boolean
isLegacyDomainRedirect: boolean
}
Expand Down
24 changes: 24 additions & 0 deletions apps/postgres-new/components/byo-llm-button.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { Brain } from 'lucide-react'
import { useApp } from '~/components/app-provider'
import { Button } from '~/components/ui/button'

export type ByoLlmButtonProps = {
onClick?: () => void
}

export default function ByoLlmButton({ onClick }: ByoLlmButtonProps) {
const { setIsModelProviderDialogOpen } = useApp()

return (
<Button
className="gap-2 text-base"
onClick={() => {
onClick?.()
setIsModelProviderDialogOpen(true)
}}
>
<Brain size={18} strokeWidth={2} />
Bring your own LLM
</Button>
)
}
76 changes: 61 additions & 15 deletions apps/postgres-new/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import { Message, generateId } from 'ai'
import { useChat } from 'ai/react'
import { AnimatePresence, m } from 'framer-motion'
import { ArrowDown, ArrowUp, Flame, Paperclip, PlugIcon, Square } from 'lucide-react'
import { AlertCircle, ArrowDown, ArrowUp, Flame, Paperclip, PlugIcon, Square } from 'lucide-react'
import {
FormEventHandler,
useCallback,
Expand All @@ -22,6 +22,7 @@ import { requestFileUpload } from '~/lib/util'
import { cn } from '~/lib/utils'
import { AiIconAnimation } from './ai-icon-animation'
import { useApp } from './app-provider'
import ByoLlmButton from './byo-llm-button'
import ChatMessage from './chat-message'
import { CopyableField } from './copyable-field'
import SignInButton from './sign-in-button'
Expand Down Expand Up @@ -51,8 +52,17 @@ export function getInitialMessages(tables: TablesData): Message[] {
}

export default function Chat() {
const { user, isLoadingUser, focusRef, setIsSignInDialogOpen, isRateLimited, liveShare } =
useApp()
const {
user,
isLoadingUser,
focusRef,
setIsSignInDialogOpen,
isRateLimited,
liveShare,
modelProvider,
modelProviderError,
setIsModelProviderDialogOpen,
} = useApp()
const [inputFocusState, setInputFocusState] = useState(false)

const {
Expand Down Expand Up @@ -155,7 +165,7 @@ export default function Chat() {
cursor: dropZoneCursor,
} = useDropZone({
async onDrop(files) {
if (!user) {
if (isAuthRequired) {
return
}

Expand Down Expand Up @@ -223,8 +233,10 @@ export default function Chat() {

const [isMessageAnimationComplete, setIsMessageAnimationComplete] = useState(false)

const isAuthRequired = user === undefined && modelProvider.state?.enabled !== true

const isChatEnabled =
!isLoadingMessages && !isLoadingSchema && user !== undefined && !liveShare.isLiveSharing
!isLoadingMessages && !isLoadingSchema && !isAuthRequired && !liveShare.isLiveSharing

const isSubmitEnabled = isChatEnabled && Boolean(input.trim())

Expand Down Expand Up @@ -293,6 +305,42 @@ export default function Chat() {
isLast={i === messages.length - 1}
/>
))}
<AnimatePresence initial={false}>
{modelProviderError && !isLoading && (
<m.div
layout="position"
className="flex flex-col gap-4 justify-start items-center max-w-96 p-4 bg-destructive rounded-md text-sm"
variants={{
hidden: { scale: 0 },
show: { scale: 1, transition: { delay: 0.5 } },
}}
initial="hidden"
animate="show"
exit="hidden"
>
<AlertCircle size={64} strokeWidth={1} />
<div className="flex flex-col items-center text-start gap-4">
<h3 className="font-bold">Whoops!</h3>
<p className="text-center">
There was an error connecting to your custom model provider:{' '}
{modelProviderError}
</p>
<p>
Double check your{' '}
<a
className="underline cursor-pointer"
onClick={() => {
setIsModelProviderDialogOpen(true)
}}
>
API info
</a>
.
</p>
</div>
</m.div>
)}
</AnimatePresence>
<AnimatePresence initial={false}>
{isRateLimited && !isLoading && (
<m.div
Expand Down Expand Up @@ -357,7 +405,7 @@ export default function Chat() {
</div>
) : (
<div className="h-full w-full max-w-4xl flex flex-col gap-10 justify-center items-center">
{user ? (
{!isAuthRequired ? (
<>
<LiveShareOverlay databaseId={databaseId} />
<m.h3
Expand All @@ -384,11 +432,10 @@ export default function Chat() {
animate="show"
>
<SignInButton />
<p className="font-lighter text-center">
To prevent abuse we ask you to sign in before chatting with AI.
</p>
or
<ByoLlmButton />
<p
className="underline cursor-pointer text-primary/50"
className="underline cursor-pointer text-sm text-primary/50"
onClick={() => {
setIsSignInDialogOpen(true)
}}
Expand Down Expand Up @@ -427,7 +474,7 @@ export default function Chat() {
</div>
<div className="flex flex-col items-center gap-3 pb-1 relative">
<AnimatePresence>
{!user && !isLoadingUser && isConversationStarted && (
{isAuthRequired && !isLoadingUser && isConversationStarted && (
<m.div
className="flex flex-col items-center gap-4 max-w-lg my-4"
variants={{
Expand All @@ -438,9 +485,8 @@ export default function Chat() {
exit="hidden"
>
<SignInButton />
<p className="font-lighter text-center text-sm">
To prevent abuse we ask you to sign in before chatting with AI.
</p>
or
<ByoLlmButton />
<p
className="underline cursor-pointer text-sm text-primary/50"
onClick={() => {
Expand Down Expand Up @@ -487,7 +533,7 @@ export default function Chat() {
onClick={async (e) => {
e.preventDefault()

if (!user) {
if (isAuthRequired) {
return
}

Expand Down
Loading