Skip to content

Optimize and improve request module #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
412 changes: 229 additions & 183 deletions src/wsapi/request.lua
Original file line number Diff line number Diff line change
@@ -1,252 +1,298 @@
local util = require"wsapi.util"
local util = require "wsapi.util"

local _M = {}
local methods = {}

-- Cache of frequently used functions
local string_match = string.match
local string_find = string.find
local string_sub = string.sub
local string_gmatch = string.gmatch
local string_gsub = string.gsub
local string_lower = string.lower
local table_insert = table.insert
local table_concat = table.concat
local tonumber = tonumber
local pairs = pairs
local type = type
local tostring = tostring
local error = error
local setmetatable = setmetatable
local rawset = rawset

-- Constants
local EOH = "\r\n\r\n"
local BOUNDARY_PATTERN = "boundary%=(.-)$"
local DISPOSITION_PATTERN = ';%s*([^%s=]+)="(.-)"'
local HEADER_PATTERN = '([^%c%s:]+):%s+([^\n]+)'
local QUERY_PATTERN = "([^&=]+)=([^&=]*)&?"
local FILENAME_PATTERN = "[/\\]?([^/\\]+)$"
local EMPTY_STRING_PATTERN = "^%s*$"
local COOKIE_SEPARATOR_PATTERN = "%s*;%s*"

-- Local auxiliary functions
local function split_filename(path)
local name_patt = "[/\\]?([^/\\]+)$"
return (string.match(path, name_patt))
return string_match(path, FILENAME_PATTERN)
end

local function insert_field (tab, name, value, overwrite)
if overwrite or not tab[name] then
tab[name] = value
else
local t = type (tab[name])
if t == "table" then
table.insert (tab[name], value)
local function insert_field(tab, name, value, overwrite)
if overwrite or not tab[name] then
tab[name] = value
else
tab[name] = { tab[name], value }
local t = type(tab[name])
if t == "table" then
table_insert(tab[name], value)
else
tab[name] = {tab[name], value}
end
end
end
end

local function parse_qs(qs, tab, overwrite)
tab = tab or {}
if type(qs) == "string" then
local url_decode = util.url_decode
for key, val in string.gmatch(qs, "([^&=]+)=([^&=]*)&?") do
insert_field(tab, url_decode(key), url_decode(val), overwrite)
tab = tab or {}
if type(qs) == "string" then
local url_decode = util.url_decode
for key, val in string_gmatch(qs, QUERY_PATTERN) do
insert_field(tab, url_decode(key), url_decode(val), overwrite)
end
elseif qs then
error("WSAPI Request error: invalid query string")
end
elseif qs then
error("WSAPI Request error: invalid query string")
end
return tab
return tab
end

local function get_boundary(content_type)
local boundary = string.match(content_type, "boundary%=(.-)$")
return "--" .. tostring(boundary)
if not content_type then return nil end
local boundary = string_match(content_type, BOUNDARY_PATTERN)
return boundary and "--" .. tostring(boundary) or nil
end

local function break_headers(header_data)
local headers = {}
for type, val in string.gmatch(header_data, '([^%c%s:]+):%s+([^\n]+)') do
type = string.lower(type)
headers[type] = val
end
return headers
local headers = {}
for htype, val in string_gmatch(header_data, HEADER_PATTERN) do
headers[string_lower(htype)] = val
end
return headers
end

local function read_field_headers(input, pos)
local EOH = "\r\n\r\n"
local s, e = string.find(input, EOH, pos, true)
if s then
return break_headers(string.sub(input, pos, s-1)), e+1
else return nil, pos end
local s, e = string_find(input, EOH, pos, true)
if s then
return break_headers(string_sub(input, pos, s - 1)), e + 1
end
return nil, pos
end

local function get_field_names(headers)
local disp_header = headers["content-disposition"] or ""
local attrs = {}
for attr, val in string.gmatch(disp_header, ';%s*([^%s=]+)="(.-)"') do
attrs[attr] = val
end
return attrs.name, attrs.filename and split_filename(attrs.filename)
local disp_header = headers["content-disposition"] or ""
local attrs = {}
for attr, val in string_gmatch(disp_header, DISPOSITION_PATTERN) do
attrs[attr] = val
end
return attrs.name, attrs.filename and split_filename(attrs.filename)
end

local function read_field_contents(input, boundary, pos)
local boundaryline = "\r\n" .. boundary
local s, e = string.find(input, boundaryline, pos, true)
if s then
return string.sub(input, pos, s-1), s-pos, e+1
else return nil, 0, pos end
local boundaryline = "\r\n" .. boundary
local s, e = string_find(input, boundaryline, pos, true)
if s then
return string_sub(input, pos, s - 1), s - pos, e + 1
end
return nil, 0, pos
end

local function file_value(file_contents, file_name, file_size, headers)
local value = { contents = file_contents, name = file_name,
size = file_size }
for h, v in pairs(headers) do
if h ~= "content-disposition" then
value[h] = v
local value = {
contents = file_contents,
name = file_name,
size = file_size
}

for h, v in pairs(headers) do
if h ~= "content-disposition" then
value[h] = v
end
end
end
return value
return value
end

local function fields(input, boundary)
local state, _ = { }
_, state.pos = string.find(input, boundary, 1, true)
state.pos = state.pos + 1
return function (state, _)
local headers, name, file_name, value, size
headers, state.pos = read_field_headers(input, state.pos)
if headers then
name, file_name = get_field_names(headers)
if file_name then
value, size, state.pos = read_field_contents(input, boundary,
state.pos)
value = file_value(value, file_name, size, headers)
else
value, size, state.pos = read_field_contents(input, boundary,
state.pos)
end
end
return name, value
end, state
local state = {}
_, state.pos = string_find(input, boundary, 1, true)
state.pos = state.pos + 1
return function(state)
local headers, name, file_name, value, size
headers, state.pos = read_field_headers(input, state.pos)
if not headers then return nil end

name, file_name = get_field_names(headers)
value, size, state.pos = read_field_contents(input, boundary, state.pos)
if file_name then
value = file_value(value, file_name, size, headers)
end
return name, value
end, state
end

local function parse_multipart_data(input, input_type, tab, overwrite)
tab = tab or {}
local boundary = get_boundary(input_type)
for name, value in fields(input, boundary) do
insert_field(tab, name, value, overwrite)
end
return tab
tab = tab or {}
local boundary = get_boundary(input_type)
if not boundary then return tab end

for name, value in fields(input, boundary) do
insert_field(tab, name, value, overwrite)
end
return tab
end

local function parse_post_data(wsapi_env, tab, overwrite)
tab = tab or {}
local input_type = wsapi_env.CONTENT_TYPE
if string.find(input_type, "x-www-form-urlencoded", 1, true) then
local length = tonumber(wsapi_env.CONTENT_LENGTH) or 0
parse_qs(wsapi_env.input:read(length) or "", tab, overwrite)
elseif string.find(input_type, "multipart/form-data", 1, true) then
tab = tab or {}
local input_type = wsapi_env.CONTENT_TYPE or ""
local input = wsapi_env.input.read
local length = tonumber(wsapi_env.CONTENT_LENGTH) or 0
if length > 0 then
parse_multipart_data(wsapi_env.input:read(length) or "", input_type, tab, overwrite)
end
else
local length = tonumber(wsapi_env.CONTENT_LENGTH) or 0
tab.post_data = wsapi_env.input:read(length) or ""
end
return tab
end

_M.methods = {}

local methods = _M.methods

function methods.__index(tab, name)
local func
if methods[name] then
func = methods[name]
else
local route_name = name:match("link_([%w_]+)")
if route_name then
func = function (self, query, ...)
return tab:route_link(route_name, query, ...)
end

if string_find(input_type, "x%-www%-form%-urlencoded", 1, true) then
parse_qs(input(length) or "", tab, overwrite)
elseif string_find(input_type, "multipart/form%-data", 1, true) and length > 0 then
parse_multipart_data(input(length) or "", input_type, tab, overwrite)
elseif length > 0 then
tab.post_data = input(length) or ""
end
end
tab[name] = func
return func

return tab
end

-- Request methods
function methods:qs_encode(query, url)
local parts = {}
for k, v in pairs(query or {}) do
parts[#parts+1] = k .. "=" .. util.url_encode(v)
end
if #parts > 0 then
return (url and (url .. "?") or "") .. table.concat(parts, "&")
else
return (url and url or "")
end
if not query or not next(query) then
return url or ""
end

local parts = {}
for k, v in pairs(query) do
table_insert(parts, k .. "=" .. util.url_encode(v))
end

return (url and (url .. "?") or "") .. table_concat(parts, "&")
end

function methods:route_link(route, query, ...)
local builder = self.mk_app["link_" .. route]
if builder then
local builder = self.mk_app and self.mk_app["link_" .. route]
if not builder then
error("there is no route named " .. route)
end

local uri = builder(self.mk_app, self.env, ...)
local qs = self:qs_encode(query)
return uri .. (qs ~= "" and ("?"..qs) or "")
else
error("there is no route named " .. route)
end
return uri .. (qs ~= "" and ("?" .. qs) or "")
end

function methods:link(url, query)
local prefix = (self.mk_app and self.mk_app.prefix) or self.script_name
local uri = prefix .. url
local qs = self:qs_encode(query)
return prefix .. url .. (qs ~= "" and ("?"..qs) or "")
local prefix = (self.mk_app and self.mk_app.prefix) or self.script_name
return self:qs_encode(query, prefix .. url)
end

function methods:absolute_link(url, query)
local qs = self:qs_encode(query)
return url .. (qs ~= "" and ("?"..qs) or "")
return self:qs_encode(query, url)
end

function methods:static_link(url)
local prefix = (self.mk_app and self.mk_app.prefix) or self.script_name
local is_script = prefix:match("(%.%w+)$")
if not is_script then return self:link(url) end
local vpath = prefix:match("(.*)/") or ""
return vpath .. url
local prefix = (self.mk_app and self.mk_app.prefix) or self.script_name
local is_script = string_match(prefix, "(%.%w+)$")

if not is_script then
return self:link(url)
end

local vpath = string_match(prefix, "(.*)/") or ""
return vpath .. url
end

function methods:empty(s)
return not s or string.match(s, "^%s*$")
return not s or string_match(s, EMPTY_STRING_PATTERN)
end

function methods:empty_param(param)
return self:empty(self.params[param])
return self:empty(self.params[param])
end

function _M.new(wsapi_env, options)
options = options or {}
local req = {
GET = {},
POST = {},
method = wsapi_env.REQUEST_METHOD,
path_info = wsapi_env.PATH_INFO,
query_string = wsapi_env.QUERY_STRING,
script_name = wsapi_env.SCRIPT_NAME,
env = wsapi_env,
mk_app = options.mk_app,
doc_root = wsapi_env.DOCUMENT_ROOT,
app_path = wsapi_env.APP_PATH
}
parse_qs(wsapi_env.QUERY_STRING, req.GET, options.overwrite)
if options.delay_post then
req.parse_post = function (self)
parse_post_data(wsapi_env, self.POST, options.overwrite)
self.parse_post = function () return nil, "postdata already parsed" end
return self.POST
options = options or {}

-- Default values ​​for WSAPI environment
local env_defaults = {
REQUEST_METHOD = "GET",
CONTENT_TYPE = "",
QUERY_STRING = "",
PATH_INFO = "",
SCRIPT_NAME = "",
APP_PATH = "",
DOCUMENT_ROOT = "",
input = {read = function() return "" end}
}

-- Applies default values
for k, v in pairs(env_defaults) do
if not wsapi_env[k] then
wsapi_env[k] = v
end
end

-- Creates the request object
local req = {
GET = {},
POST = {},
method = wsapi_env.REQUEST_METHOD,
path_info = wsapi_env.PATH_INFO,
query_string = wsapi_env.QUERY_STRING,
script_name = wsapi_env.SCRIPT_NAME,
env = wsapi_env,
mk_app = options.mk_app,
doc_root = wsapi_env.DOCUMENT_ROOT,
app_path = wsapi_env.APP_PATH
}

-- Parse request data
parse_qs(wsapi_env.QUERY_STRING, req.GET, options.overwrite)

if options.delay_post then
req.parse_post = function(self)
parse_post_data(wsapi_env, self.POST, options.overwrite)
self.parse_post = function() return nil, "postdata already parsed" end
return self.POST
end
else
parse_post_data(wsapi_env, req.POST, options.overwrite)
req.parse_post = function() return nil, "postdata already parsed" end
end
else
parse_post_data(wsapi_env, req.POST, options.overwrite)
req.parse_post = function () return nil, "postdata already parsed" end
end
req.params = {}
setmetatable(req.params, { __index = function (tab, name)
local var = req.GET[name] or req.POST[name]
rawset(tab, name, var)
return var
end})
req.cookies = {}
local cookies = string.gsub(";" .. (wsapi_env.HTTP_COOKIE or "") .. ";",
"%s*;%s*", ";")
setmetatable(req.cookies, { __index = function (tab, name)
name = name
local pattern = ";" .. name ..
"=(.-);"
local cookie = string.match(cookies, pattern)
cookie = util.url_decode(cookie)
rawset(tab, name, cookie)
return cookie
end})
return setmetatable(req, methods)
end

return _M

-- Combined parameters (GET + POST)
req.params = setmetatable({}, {
__index = function(tab, name)
local var = req.GET[name] or req.POST[name]
rawset(tab, name, var)
return var
end
})

-- Cookies
req.cookies = setmetatable({}, {
__index = function(tab, name)
local cookies = string_gsub(";" .. (wsapi_env.HTTP_COOKIE or "") .. ";",
COOKIE_SEPARATOR_PATTERN, ";")
local pattern = ";" .. name .. "=(.-);"
local cookie = string_match(cookies, pattern)
cookie = cookie and util.url_decode(cookie)
rawset(tab, name, cookie)
return cookie
end
})

return setmetatable(req, {__index = methods})
end

return _M