6 changes: 3 additions & 3 deletions package.json
@@ -21,8 +21,8 @@
     "dotenv": "^16.0.3",
     "eslint": "8.34.0",
     "eslint-config-next": "13.1.6",
-    "hnswlib-node": "^1.2.0",
-    "langchain": "0.0.15",
+    "hnswlib-node": "^1.4.2",
+    "langchain": "0.0.47",
     "next": "13.1.6",
     "openai": "^3.1.0",
     "react": "18.2.0",
@@ -40,4 +40,4 @@
     "tsx": "^3.12.3",
     "typescript": "4.9.5"
   }
-}
+}
65 changes: 38 additions & 27 deletions pages/api/chat-stream.ts
@@ -1,13 +1,16 @@
 // Next.js API route support: https://nextjs.org/docs/api-routes/introduction
-import type { NextApiRequest, NextApiResponse } from 'next'
+import type { NextApiRequest, NextApiResponse } from "next";
 import type { Server as HttpServer } from "http";
 import type { Server as HttpsServer } from "https";
-import { WebSocketServer } from 'ws';
+import { WebSocketServer } from "ws";
 import { HNSWLib } from "langchain/vectorstores";
-import { OpenAIEmbeddings } from 'langchain/embeddings';
-import { makeChain } from "./util";
+import { OpenAIEmbeddings } from "langchain/embeddings";
+import { formatHistory, makeChain } from "./util";
 
-export default async function handler(req: NextApiRequest, res: NextApiResponse) {
+export default async function handler(
+  req: NextApiRequest,
+  res: NextApiResponse
+) {
   if ((res.socket as any).server.wss) {
     res.end();
     return;
@@ -16,52 +19,60 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
   const server = (res.socket as any).server as HttpsServer | HttpServer;
   const wss = new WebSocketServer({ noServer: true });
   (res.socket as any).server.wss = wss;
-  server.on('upgrade', (req, socket, head) => {
-    if (!req.url?.includes('/_next/webpack-hmr')) {
+
+  server.on("upgrade", (req, socket, head) => {
+    if (!req.url?.includes("/_next/webpack-hmr")) {
       wss.handleUpgrade(req, socket, head, (ws) => {
-        wss.emit('connection', ws, req);
+        wss.emit("connection", ws, req);
       });
     }
   });
 
-  wss.on('connection', (ws) => {
-    const sendResponse = ({ sender, message, type }: { sender: string, message: string, type: string }) => {
+  wss.on("connection", (ws) => {
+    const sendResponse = ({
+      sender,
+      message,
+      type,
+    }: {
+      sender: string;
+      message: string;
+      type: string;
+    }) => {
       ws.send(JSON.stringify({ sender, message, type }));
     };
 
-    const onNewToken = (token: string) => {
-      sendResponse({ sender: 'bot', message: token, type: 'stream' });
-    }
+    const onNewToken = async (token: string) => {
+      sendResponse({ sender: "bot", message: token, type: "stream" });
+    };
 
-    const chainPromise = HNSWLib.load("data", new OpenAIEmbeddings()).then((vs) => makeChain(vs, onNewToken));
+    const chainPromise = HNSWLib.load("data", new OpenAIEmbeddings()).then(
+      (vs) => makeChain(vs, onNewToken)
+    );
     const chatHistory: [string, string][] = [];
     const encoder = new TextEncoder();
 
 
-    ws.on('message', async (data) => {
+    ws.on("message", async (data) => {
       try {
         const question = data.toString();
-        sendResponse({ sender: 'you', message: question, type: 'stream' });
+        sendResponse({ sender: "you", message: question, type: "stream" });
 
-        sendResponse({ sender: 'bot', message: "", type: 'start' });
+        sendResponse({ sender: "bot", message: "", type: "start" });
         const chain = await chainPromise;
 
         const result = await chain.call({
-          question,
-          chat_history: chatHistory,
+          question,
+          chat_history: formatHistory(chatHistory),
         });
         chatHistory.push([question, result.answer]);
 
-        sendResponse({ sender: 'bot', message: "", type: 'end' });
+        sendResponse({ sender: "bot", message: "", type: "end" });
       } catch (e) {
         sendResponse({
-          sender: 'bot',
-          message: "Sorry, something went wrong. Try again.",
-          type: 'error'
+          sender: "bot",
+          message: "Sorry, something went wrong. Try again.",
+          type: "error",
        });
       }
-    })
+    });
   });
 
   res.end();
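For reference, here is a minimal browser-side sketch of the wire protocol the handler above keeps using: each message is a JSON object of the `{ sender, message, type }` shape emitted by `sendResponse`, and questions go over the socket as plain strings. The endpoint URL, the assumption that the API route has already been requested once (so the upgrade handler is installed), and the logging are illustrative, not part of this PR.

```ts
// Minimal client sketch for the WebSocket protocol above.
// Assumptions: runs in a browser, the /api/chat-stream route has been hit once
// so the upgrade handler is registered, and the path below is only illustrative.
const ws = new WebSocket(`ws://${location.host}/api/chat-stream`);

let botMessage = "";

ws.onmessage = (event) => {
  const { sender, message, type } = JSON.parse(event.data) as {
    sender: string;
    message: string;
    type: string;
  };
  if (type === "start") {
    botMessage = ""; // the bot is starting a new streamed answer
  } else if (type === "stream" && sender === "bot") {
    botMessage += message; // append one streamed token
  } else if (type === "end") {
    console.log("answer:", botMessage); // streaming finished
  } else if (type === "error") {
    console.error(message);
  }
};

// Questions are sent as plain strings; the server echoes them back with sender "you".
ws.onopen = () => ws.send("What is LangChain?");
```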
6 changes: 3 additions & 3 deletions pages/api/chat.ts
@@ -3,7 +3,7 @@ import type { NextApiRequest, NextApiResponse } from "next";
 import path from "path";
 import { HNSWLib } from "langchain/vectorstores";
 import { OpenAIEmbeddings } from "langchain/embeddings";
-import { makeChain } from "./util";
+import { formatHistory, makeChain } from "./util";
 
 export default async function handler(
   req: NextApiRequest,
@@ -27,14 +27,14 @@ export default async function handler(
   };
 
   sendData(JSON.stringify({ data: "" }));
-  const chain = makeChain(vectorstore, (token: string) => {
+  const chain = makeChain(vectorstore, async (token: string) => {
     sendData(JSON.stringify({ data: token }));
   });
 
   try {
     await chain.call({
       question: body.question,
-      chat_history: body.history,
+      chat_history: formatHistory(body.history),
     });
   } catch (err) {
     console.error(err);
61 changes: 40 additions & 21 deletions pages/api/util.ts
@@ -1,14 +1,27 @@
-import { OpenAI } from "langchain/llms";
-import { LLMChain, ChatVectorDBQAChain, loadQAChain } from "langchain/chains";
+import { ChatOpenAI } from "langchain/chat_models";
+import {
+  LLMChain,
+  ConversationalRetrievalQAChain,
+  loadQAStuffChain,
+} from "langchain/chains";
 import { HNSWLib } from "langchain/vectorstores";
-import { PromptTemplate } from "langchain/prompts";
+import {
+  ChatPromptTemplate,
+  HumanMessagePromptTemplate,
+  MessagesPlaceholder,
+  PromptTemplate,
+  SystemMessagePromptTemplate,
+} from "langchain/prompts";
+import { CallbackManager } from "langchain/callbacks";
+import { AIChatMessage, HumanChatMessage } from "langchain/schema";
 
-const CONDENSE_PROMPT = PromptTemplate.fromTemplate(`Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
-
-Chat History:
-{chat_history}
-Follow Up Input: {question}
-Standalone question:`);
+const CONDENSE_PROMPT = ChatPromptTemplate.fromPromptMessages([
+  SystemMessagePromptTemplate.fromTemplate(
+    `Given the following conversation between a user and an assistant, rephrase the last question from the user to be a standalone question.`
+  ),
+  new MessagesPlaceholder("chat_history"),
+  HumanMessagePromptTemplate.fromTemplate(`Last question: {question}`),
+]);
 
 const QA_PROMPT = PromptTemplate.fromTemplate(
   `You are an AI assistant for the open source library LangChain. The documentation is located at https://langchain.readthedocs.io.
@@ -21,28 +34,34 @@ Question: {question}
 =========
 {context}
 =========
-Answer in Markdown:`);
+Answer in Markdown:`
+);
 
-export const makeChain = (vectorstore: HNSWLib, onTokenStream?: (token: string) => void) => {
+export const makeChain = (
+  vectorstore: HNSWLib,
+  onTokenStream?: (token: string) => Promise<void>
+) => {
   const questionGenerator = new LLMChain({
-    llm: new OpenAI({ temperature: 0 }),
+    llm: new ChatOpenAI({ temperature: 0 }),
     prompt: CONDENSE_PROMPT,
   });
-  const docChain = loadQAChain(
-    new OpenAI({
+  const docChain = loadQAStuffChain(
+    new ChatOpenAI({
       temperature: 0,
       streaming: Boolean(onTokenStream),
-      callbackManager: {
-        handleNewToken: onTokenStream,
-      }
+      callbackManager: CallbackManager.fromHandlers({
+        handleLLMNewToken: onTokenStream,
+      }),
     }),
-    { prompt: QA_PROMPT },
+    { prompt: QA_PROMPT }
   );
 
-  return new ChatVectorDBQAChain({
-    vectorstore,
+  return new ConversationalRetrievalQAChain({
+    retriever: vectorstore.asRetriever(),
     combineDocumentsChain: docChain,
     questionGeneratorChain: questionGenerator,
   });
-}
+};
 
+export const formatHistory = (history: [string, string][]) =>
+  history.flatMap(([q, a]) => [new HumanChatMessage(q), new AIChatMessage(a)]);
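As a rough usage sketch of the reworked exports (not part of the diff): `makeChain` now takes an async token callback, and `chat_history` is expected to go through `formatHistory` so the `MessagesPlaceholder` in `CONDENSE_PROMPT` receives `HumanChatMessage`/`AIChatMessage` objects, mirroring what `pages/api/chat.ts` does. The `data` directory, the sample history and question, and the console output below are assumptions for illustration.

```ts
// Usage sketch only: assumes OPENAI_API_KEY is set, the HNSW index lives in "data",
// and this file sits next to util.ts.
import { HNSWLib } from "langchain/vectorstores";
import { OpenAIEmbeddings } from "langchain/embeddings";
import { formatHistory, makeChain } from "./util";

async function ask() {
  const history: [string, string][] = [
    ["What is LangChain?", "A framework for building LLM applications."],
  ];

  const vectorstore = await HNSWLib.load("data", new OpenAIEmbeddings());
  const chain = makeChain(vectorstore, async (token) => {
    process.stdout.write(token); // stream tokens as they arrive
  });

  const result = await chain.call({
    question: "How do I install it?",
    chat_history: formatHistory(history), // HumanChatMessage / AIChatMessage pairs
  });
  // The chat-stream handler above reads the final reply from result.answer.
  console.log("\n", result.answer);
}

ask().catch(console.error);
```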