Commit 3db7e28

feat: add api function to cancel request
Signed-off-by: blob42 <[email protected]>
1 parent 0c70aab

8 files changed: +58 additions, −42 deletions

lua/codegpt.lua

Lines changed: 11 additions & 9 deletions

@@ -1,5 +1,6 @@
-local config = require("codegpt.config")
-local models = require("codegpt.models")
+local Config = require("codegpt.config")
+local Models = require("codegpt.models")
+local Api = require("codegpt.api")
 local M = {}
 
 local Commands = require("codegpt.commands")
@@ -17,16 +18,16 @@ end
 
 function M.run_cmd(opts)
     if opts.name and opts.name:match("^V") then
-        config.popup_override = "vertical"
+        Config.popup_override = "vertical"
     else
-        config.popup_override = nil
+        Config.popup_override = nil
     end
 
     -- bang makes popup persistent until closed
     if opts.bang then
-        config.persistent_override = true
+        Config.persistent_override = true
     else
-        config.persistent_override = false
+        Config.persistent_override = false
     end
 
     local text_selection, bounds = Utils.get_selected_lines(opts)
@@ -46,7 +47,7 @@ function M.run_cmd(opts)
             command_args = ""
         elseif text_selection == "" then
             command = "chat"
-        elseif config.opts.commands[command] == nil then
+        elseif Config.opts.commands[command] == nil then
             command = "code_edit"
         end
     elseif text_selection ~= "" and command_args == "" then
@@ -63,7 +64,8 @@ function M.run_cmd(opts)
     Commands.run_cmd(command, command_args, text_selection, bounds)
 end
 
-M.setup = config.setup
-M.select_model = models.select_model
+M.setup = Config.setup
+M.select_model = Models.select_model
+M.cancel_request = Api.cancel_job
 
 return M
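With `M.cancel_request` exported from the entry module, an in-flight request can now be aborted from user config. A minimal sketch of wiring it up — the `CodeGPTCancel` command name and the `<leader>cx` mapping are illustrative choices, not part of this commit:

```lua
-- User-side wiring for the new cancel function (hypothetical command/keymap names).
local codegpt = require("codegpt")

vim.api.nvim_create_user_command("CodeGPTCancel", function()
  codegpt.cancel_request() -- shuts down the tracked curl job, if any
end, { desc = "Cancel the in-flight CodeGPT request" })

vim.keymap.set("n", "<leader>cx", codegpt.cancel_request, { desc = "Cancel CodeGPT request" })
```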

lua/codegpt/api.lua

Lines changed: 14 additions & 6 deletions

@@ -1,8 +1,9 @@
 local Config = require("codegpt.config")
 
-local Api = {}
+local M = {}
 
-CODEGPT_CALLBACK_COUNTER = 0
+local CODEGPT_CALLBACK_COUNTER = 0
+M.current_job = nil
 
 local status_index = 0
 local timer = vim.uv.new_timer()
@@ -18,7 +19,7 @@ local function start_spinner_timer()
     )
 end
 
-function Api.get_status(...)
+function M.get_status(...)
     local spinners = Config.opts.ui.spinners or { "", "", "", "", "", "" }
     local spinner_speed = Config.opts.ui.spinner_speed or 80
     local ms = vim.uv.hrtime() / 1000000
@@ -40,15 +41,15 @@ function Api.get_status(...)
     end
 end
 
-function Api.run_started_hook()
+function M.run_started_hook()
     if Config.opts.hooks.request_started ~= nil then
         Config.opts.hooks.request_started()
     end
 
     CODEGPT_CALLBACK_COUNTER = CODEGPT_CALLBACK_COUNTER + 1
 end
 
-function Api.run_finished_hook()
+function M.run_finished_hook()
     if CODEGPT_CALLBACK_COUNTER > 0 then
         CODEGPT_CALLBACK_COUNTER = CODEGPT_CALLBACK_COUNTER - 1
     end
@@ -59,4 +60,11 @@ function Api.run_finished_hook()
     end
 end
 
-return Api
+function M.cancel_job()
+    if M.current_job ~= nil then
+        M.current_job:shutdown()
+        M.run_finished_hook()
+    end
+end
+
+return M
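The cancellation mechanism is small: providers store the handle returned by plenary's async `curl.post` in `M.current_job`, and `M.cancel_job` calls `:shutdown()` on it (plenary's Job method, which terminates the underlying curl process) before decrementing the callback counter via `run_finished_hook`. A reduced sketch of the same pattern in isolation, assuming plenary.nvim is available; `fetch` and `cancel` are illustrative names:

```lua
-- Standalone sketch of the job-tracking pattern used in this commit (assumes plenary.nvim).
local curl = require("plenary.curl")

local M = { current_job = nil }

function M.fetch(url, on_done)
  -- when a callback is given, plenary runs the request asynchronously and returns the Job
  M.current_job = curl.get(url, {
    callback = function(response)
      vim.schedule(function()
        on_done(response)
      end)
    end,
  })
end

function M.cancel()
  if M.current_job ~= nil then
    M.current_job:shutdown() -- kills the curl process; job.is_shutdown becomes true
  end
end

return M
```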

lua/codegpt/providers/anthropic.lua

Lines changed: 1 addition & 1 deletion

@@ -112,7 +112,7 @@ function M.make_call(payload, cb)
     local url = "https://api.anthropic.com/v1/messages"
     local headers = M.make_headers()
     Api.run_started_hook()
-    curl.post(url, {
+    Api.current_job = curl.post(url, {
         body = payload_str,
         headers = headers,
         callback = function(response)

lua/codegpt/providers/azure.lua

Lines changed: 1 addition & 1 deletion

@@ -198,7 +198,7 @@ function M.make_call(payload, cb)
     local url = Config.opts.connection.chat_completions_url
     local headers = M.make_headers()
     Api.run_started_hook()
-    curl.post(url, {
+    Api.current_job = curl.post(url, {
         body = payload_str,
         headers = headers,
         callback = function(response)

lua/codegpt/providers/groq.lua

Lines changed: 1 addition & 1 deletion

@@ -126,7 +126,7 @@ function M.make_call(payload, cb)
     local url = "https://api.groq.com/openai/v1/chat/completions"
     local headers = M.make_headers()
     Api.run_started_hook()
-    curl.post(url, {
+    Api.current_job = curl.post(url, {
         body = payload_str,
         headers = headers,
         callback = function(response)

lua/codegpt/providers/ollama.lua

Lines changed: 5 additions & 5 deletions

@@ -3,7 +3,7 @@ local Render = require("codegpt.template_render")
 local Utils = require("codegpt.utils")
 local Api = require("codegpt.api")
 local Config = require("codegpt.config")
-local Tokens = require("codegpt.tokens")
+local tokens = require("codegpt.tokens")
 local errors = require("codegpt.errors")
 
 local M = {}
@@ -30,7 +30,7 @@ local function generate_messages(command, cmd_opts, command_args, text_selection)
 end
 
 local function get_max_tokens(max_tokens, prompt)
-    local total_length = Tokens.get_tokens(prompt)
+    local total_length = tokens.get_tokens(prompt)
 
     if total_length >= max_tokens then
         error("Total length of messages exceeds max_tokens: " .. total_length .. " > " .. max_tokens)
@@ -160,7 +160,7 @@ function M.make_call(payload, cb)
     local url = Config.opts.connection.ollama_base_url:gsub("/$", "") .. "/api/chat"
     local headers = M.make_headers()
     Api.run_started_hook()
-    curl.post(url, {
+    Api.current_job = curl.post(url, {
         body = payload_str,
         headers = headers,
         callback = function(response)
@@ -173,13 +173,13 @@ function M.make_call(payload, cb)
 end
 
 ---@param payload table payload sent to api
----@param stream_cb fun(data: table) callback to handle the resonse json stream
+---@param stream_cb fun(data: table, job: table) callback to handle the resonse json stream
 function M.make_stream_call(payload, stream_cb)
     local payload_str = vim.fn.json_encode(payload)
     local url = Config.opts.connection.ollama_base_url:gsub("/$", "") .. "/api/chat"
     local headers = M.make_headers()
     Api.run_started_hook()
-    curl.post(url, {
+    Api.current_job = curl.post(url, {
         body = payload_str,
         headers = headers,
         stream = function(error, data, job)
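The updated annotation reflects that the stream callback now also receives the plenary job, so consumers (here the streaming popup) can drop chunks that arrive after a cancel. A sketch of such a callback, assuming the provider hands it the decoded Ollama `/api/chat` chunk as `data`; `handle_chunk` is a hypothetical name and the field access follows Ollama's documented streaming format:

```lua
-- Hypothetical stream consumer for make_stream_call.
local function handle_chunk(data, job)
  -- ignore chunks delivered after the request was cancelled
  if job ~= nil and job.is_shutdown then
    return
  end
  -- Ollama's /api/chat stream carries the incremental text under message.content
  if data.message and data.message.content then
    vim.schedule(function()
      print(data.message.content)
    end)
  end
end
```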

lua/codegpt/providers/openai.lua

Lines changed: 3 additions & 3 deletions

@@ -3,7 +3,7 @@ local Render = require("codegpt.template_render")
 local Utils = require("codegpt.utils")
 local Api = require("codegpt.api")
 local Config = require("codegpt.config")
-local Tokens = require("codegpt.tokens")
+local tokens = require("codegpt.tokens")
 local errors = require("codegpt.errors")
 
 -- TODO: handle streaming mode
@@ -32,7 +32,7 @@ local function generate_messages(command, cmd_opts, command_args, text_selection)
 end
 
 local function get_max_tokens(max_tokens, messages)
-    local total_length = Tokens.get_tokens(messages)
+    local total_length = tokens.get_tokens(messages)
 
     if total_length >= max_tokens then
         error("Total length of messages exceeds max_tokens: " .. total_length .. " > " .. max_tokens)
@@ -151,7 +151,7 @@ function M.make_call(payload, cb)
     local url = Config.opts.connection.chat_completions_url
     local headers = M.make_headers()
     Api.run_started_hook()
-    curl.post(url, {
+    Api.current_job = curl.post(url, {
         body = payload_str,
         headers = headers,
         callback = function(response)

lua/codegpt/ui.lua

Lines changed: 22 additions & 16 deletions

@@ -1,6 +1,6 @@
 local Popup = require("nui.popup")
 local Split = require("nui.split")
-local config = require("codegpt.config")
+local Config = require("codegpt.config")
 local event = require("nui.utils.autocmd").event
 
 local M = {}
@@ -14,7 +14,7 @@ local function create_horizontal()
         split = Split({
             relative = "editor",
             position = "bottom",
-            size = config.opts.ui.horizontal_popup_size,
+            size = Config.opts.ui.horizontal_popup_size,
         })
     end
 
@@ -26,7 +26,7 @@ local function create_vertical()
         split = Split({
             relative = "editor",
             position = "right",
-            size = config.opts.ui.vertical_popup_size,
+            size = Config.opts.ui.vertical_popup_size,
         })
     end
 
@@ -35,15 +35,15 @@ end
 
 local function create_floating()
     if not popup then
-        local window_options = config.opts.ui.popup_window_options
+        local window_options = Config.opts.ui.popup_window_options
         if window_options == nil then
            window_options = {}
        end
 
        local popupOpts = {
            enter = true,
            focusable = true,
-            border = config.opts.ui.popup_border,
+            border = Config.opts.ui.popup_border,
            position = "50%",
            size = {
                width = "80%",
@@ -58,13 +58,13 @@ local function create_floating()
        popup = Popup(popupOpts)
    end
 
-    popup:update_layout(config.opts.ui.popup_options)
+    popup:update_layout(Config.opts.ui.popup_options)
 
    return popup
 end
 
 local function create_window()
-    local popup_type = config.popup_override or config.opts.ui.popup_type
+    local popup_type = Config.popup_override or Config.opts.ui.popup_type
    local ui_elem = nil
    if popup_type == "horizontal" then
        ui_elem = create_horizontal()
@@ -78,24 +78,27 @@ local function create_window()
 end
 
 function M.popup(job, lines, filetype, bufnr, start_row, start_col, end_row, end_col)
+    if job ~= nil and job.is_shutdown then
+        return
+    end
    local ui_elem = create_window()
    -- mount/open the component
    ui_elem:mount()
 
-    if not (config.persistent_override or config.opts.ui.persistent) then
+    if not (Config.persistent_override or Config.opts.ui.persistent) then
        -- unmount component when cursor leaves buffer
        ui_elem:on(event.BufLeave, function()
            ui_elem:unmount()
        end)
    end
 
    -- unmount component when key 'q'
-    ui_elem:map("n", config.opts.ui.actions.quit, function()
+    ui_elem:map("n", Config.opts.ui.actions.quit, function()
        ui_elem:unmount()
    end, { noremap = true, silent = true })
    --
    -- cancel job if actions.cancel is called
-    ui_elem:map("n", config.opts.ui.actions.cancel, function()
+    ui_elem:map("n", Config.opts.ui.actions.cancel, function()
        job:shutdown()
    end, { noremap = true, silent = true })
 
@@ -104,19 +107,19 @@ function M.popup(job, lines, filetype, bufnr, start_row, start_col, end_row, end_col)
    vim.api.nvim_buf_set_lines(ui_elem.bufnr, 0, 1, false, lines)
 
    -- replace lines when ctrl-o pressed
-    ui_elem:map("n", config.opts.ui.actions.use_as_output, function()
+    ui_elem:map("n", Config.opts.ui.actions.use_as_output, function()
        vim.api.nvim_buf_set_text(bufnr, start_row, start_col, end_row, end_col, lines)
        ui_elem:unmount()
    end)
 
    -- selecting all the content when ctrl-i is pressed
    -- so the user can proceed with another API request
-    ui_elem:map("n", config.opts.ui.actions.use_as_input, function()
+    ui_elem:map("n", Config.opts.ui.actions.use_as_input, function()
        vim.api.nvim_feedkeys("ggVG:Chat ", "n", false)
    end, { noremap = false })
 
    -- mapping custom commands
-    for _, command in ipairs(config.opts.ui.actions.custom) do
+    for _, command in ipairs(Config.opts.ui.actions.custom) do
        ui_elem:map(command[1], command[2], command[3], command[4])
    end
 end
@@ -125,27 +128,30 @@ local streaming = false
 local stream_ui_elem = nil
 
 function M.popup_stream(job, stream, filetype, bufnr, start_row, start_col, end_row, end_col)
+    if job ~= nil and job.is_shutdown then
+        return
+    end
    if not streaming then
        streaming = true
        stream_ui_elem = create_window()
 
        -- mount/open the component
        stream_ui_elem:mount()
 
-        if not (config.persistent_override or config.opts.ui.persistent) then
+        if not (Config.persistent_override or Config.opts.ui.persistent) then
            -- unmount component when cursor leaves buffer
            stream_ui_elem:on(event.BufLeave, function()
                stream_ui_elem:unmount()
            end)
        end
 
        -- unmount component when key 'q'
-        stream_ui_elem:map("n", config.opts.ui.actions.quit, function()
+        stream_ui_elem:map("n", Config.opts.ui.actions.quit, function()
            stream_ui_elem:unmount()
        end, { noremap = true, silent = true })
 
        -- cancel job if actions.cancel is called
-        stream_ui_elem:map("n", config.opts.ui.actions.cancel, function()
+        stream_ui_elem:map("n", Config.opts.ui.actions.cancel, function()
            job:shutdown()
        end, { noremap = true, silent = true })
 
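Both popup entry points now return early when the job they were called for has already been shut down, so a cancelled request never opens, or keeps writing into, a result window. The in-window cancel mapping itself comes from `Config.opts.ui.actions.cancel`. A hedged `setup()` excerpt showing where those options live — only the option paths are taken from the code above; the key choices are illustrative and `setup()` is assumed to merge user options into `Config.opts`:

```lua
-- Illustrative setup() excerpt; option names mirror the fields referenced in ui.lua.
require("codegpt").setup({
  ui = {
    persistent = false,        -- when true, the result window survives BufLeave
    actions = {
      quit = "q",              -- close the result window
      cancel = "<c-c>",        -- shut down the in-flight job from the window
      use_as_output = "<c-o>", -- replace the original selection with the result
      use_as_input = "<c-i>",  -- reuse the result as the prompt for a new :Chat
      custom = {},             -- extra { mode, lhs, rhs, opts } mappings
    },
  },
})
```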
