--- lib/db/sqlite3.lua
-- copied from: https://github.com/SinisterRectus/lit-sqlite3
-- see also higher level binding: https://github.com/kkharji/sqlite.lua
-- see also: https://github.com/Playermet/luajit-sqlite/blob/master/sqlite3.lua
-- see also: https://github.com/shakesoda/luajit-sqlite-ffi, fork of: https://github.com/Wiladams/LJIT2SQLite
-- see also: https://github.com/CentauriSoldier/SQLite3-for-Lua
--------------------------------------------------------------------------------
-- A library for interfacing with SQLite3 databases.
--
-- Copyright (C) 2011-2016 Stefano Peluchetti. All rights reserved.
--
-- Features, documentation and more: http://www.scilua.org .
--
-- This file is part of the LJSQLite3 library, which is released under the MIT
-- license: full text in file LICENSE.TXT in the library's root folder.
--------------------------------------------------------------------------------
--[[lit-meta
  name = "SinisterRectus/sqlite3"
  version = "1.0.1"
  homepage = "http://scilua.org/ljsqlite3.html"
  description = "SciLua's sqlite3 bindings repackaged for lit."
  tags = {"database", "sqlite3"}
  license = "MIT"
  author = "Stefano Peluchetti"
  contributors = {
    "Sinister Rectus"
  }
]] --------------------------------------------------------------------------------
-- TODO: Refactor according to the latest style / coding guidelines.
--       Introduce getters for a known column type, to avoid the per-value type check?
--       Add extended error codes from SQLite?
--       Consider argument type checks?
--       Is exposing the cdef constants OK?
--       The result set (and therefore exec) could be optimized by avoiding loads/stores of the row table via _step?
--
local ffi = require "ffi"
local bit = require "bit"
local jit = require "jit"
local sqlite3 = require "db/sqlite3-ffi"
local countOccurrence = require"peg".countOccurrence
local util
local unpack = table.unpack
local sql = sqlite3.lib
local codes = sqlite3.codes

---- xsys replacement ----------------------------------------------------------

local insert = table.insert
local match, gmatch = string.match, string.gmatch

--- Split `str` on every occurrence of `delim`.
-- Empty fields are kept, so "a;;b" yields {"a", "", "b"}. `delim` is used as a
-- Lua pattern.
-- @return array of the pieces, in order
local function split(str, delim)
	local pieces, count = {}, 0
	for piece in gmatch(str .. delim, '(.-)' .. delim) do
		count = count + 1
		pieces[count] = piece
	end
	return pieces
end

--- Strip leading and trailing whitespace from `str`.
local function trim(str)
	local core = match(str, '^%s*(.-)%s*$')
	return core
end

--------------------------------------------------------------------------------

--- Build an "sqlite3 <code>: <msg>" error message.
-- The message is printed in red unless `printError` is false.
-- @param code       short error name (e.g. "misuse", "constraint")
-- @param msg        human-readable description
-- @param printError pass false to suppress printing (nil or true print)
-- @return the error message. When printing, this is util.printRed's return
--   value if truthy, otherwise the plain message. Callers therefore always get
--   a non-nil error value to test with `if err then`.
local function sqlError(code, msg, printError)
	local text = "sqlite3 " .. code .. ": " .. msg
	if printError ~= false then
		util = util or require "util"
		-- NOTE(review): util.printRed's return value is not visible here. Fall
		-- back to the plain message so error checks never see nil.
		return util.printRed(text) or text
	end
	return text
end

--------------------------------------------------------------------------------

-- SQLITE_TRANSIENT: the destructor value -1 tells SQLite to copy bound
-- text/blob data, so the Lua string does not have to outlive the binding.
local transient = ffi.cast("sqlite3_destructor_type", -1)
-- ctype used by set_column/set_value to recognise int64_t cdata, which is
-- bound/returned as a 64-bit integer instead of a double.
local int64_ct = ffi.typeof("int64_t")

local blob_mt = {} -- For tagging only: marks tables created by blob().

--- Wrap a Lua string so it is bound/returned as an SQLite BLOB instead of TEXT.
-- @param str raw bytes
-- @return a one-element table tagged with blob_mt
local function blob(str)
	local wrapped = {str}
	return setmetatable(wrapped, blob_mt)
end

-- Registries keyed by connection cdata. connStmt and connCb are ordinary
-- (strong) tables, and conn_mt:close() removes a connection's entries
-- explicitly.
-- NOTE(review): because the keys are strong, an unclosed connection stays
-- reachable from here, so its __gc finalizer presumably never runs -- confirm.
local connStmt = {} -- Statements for a conn (each value is a weak-keyed set, see open()).
local connCb = {} -- Callbacks for a conn: {scalar = {}, step = {}, final = {}}, by function name.
local aggregateState = {} -- Aggregate states, keyed by aggregate-context address (see getstate()).

-- Forward declaration: conn_mt:setscalar/setaggregate pass this function to
-- jit.off(), but it is only defined further down.
local stmt_step

-- Metatables for the statement and connection cdata. The matching ctypes
-- (stmt_ct, conn_ct) are created later with ffi.metatype.
local stmt_mt, stmt_ct = {}
stmt_mt.__index = stmt_mt

local conn_mt, conn_ct = {}
conn_mt.__index = conn_mt

-- Checks ----------------------------------------------------------------------

--- Look up the error name and message for an SQLite result code.
-- @param pConn sqlite3* handle whose last error message is read
-- @param code  SQLite result code (index into `codes`)
-- @return lower-cased code name, error message text from sqlite3_errmsg
local function codeMsg(pConn, code)
	local name = codes[code]:lower()
	local text = ffi.string(sql.sqlite3_errmsg(pConn))
	return name, text
end

--- Build the error message for `code` on the given connection.
-- Despite the historical name, this does not raise. It returns whatever
-- sqlError returns.
-- @param pConn      sqlite3* handle
-- @param code       SQLite result code
-- @param printError forwarded to sqlError (false suppresses printing)
local function E_conn(pConn, code, printError)
	local name, text = codeMsg(pConn, code)
	return sqlError(name, text, printError)
end

--- Check an SQLite result code.
-- @return nothing when `code` is SQLITE_OK; otherwise the error message built
--   by E_conn for that connection.
local function T_okCode(pConn, code)
	if code == sql.SQLITE_OK then
		return
	end
	return E_conn(pConn, code)
end

--- Check that a connection or statement has not been closed.
-- @return nothing if `x` is open; otherwise a "misuse" error message.
local function T_open(x)
	if not x._closed then
		return
	end
	return sqlError("misuse", "object is closed")
end

-- The getters/setters below were originally generated by the code kept in the comment at the end of this file.
--- Read column `i` (0-based) of the statement's current row as a Lua value.
-- INTEGER is converted with tonumber() to a Lua number (double), so integers
-- beyond 2^53 lose precision. FLOAT becomes a number. TEXT and BLOB become Lua
-- strings (length from *_bytes, so embedded zeros are kept). NULL becomes nil.
-- @return the value, or nil and an error message for an unknown type code
local function get_column(stmtOrValue, i)
	local kind = sql.sqlite3_column_type(stmtOrValue, i)
	if kind == sql.SQLITE_NULL then
		return nil
	end
	if kind == sql.SQLITE_INTEGER then
		return tonumber(sql.sqlite3_column_int64(stmtOrValue, i))
	end
	if kind == sql.SQLITE_FLOAT then
		return sql.sqlite3_column_double(stmtOrValue, i)
	end
	if kind == sql.SQLITE_TEXT then
		local size = sql.sqlite3_column_bytes(stmtOrValue, i)
		return ffi.string(sql.sqlite3_column_text(stmtOrValue, i), size)
	end
	if kind == sql.SQLITE_BLOB then
		local size = sql.sqlite3_column_bytes(stmtOrValue, i)
		return ffi.string(sql.sqlite3_column_blob(stmtOrValue, i), size)
	end
	return nil, sqlError("constraint", "unexpected SQLite3 type")
end

--- Convert an sqlite3_value (an SQL function argument) to a Lua value.
-- Uses the same type mapping as get_column: INTEGER becomes a number via
-- tonumber(), FLOAT a number, TEXT and BLOB strings, and NULL nil.
-- @return the value, or nil and an error message for an unknown type code
local function get_value(stmtOrValue)
	local kind = sql.sqlite3_value_type(stmtOrValue)
	if kind == sql.SQLITE_NULL then
		return nil
	elseif kind == sql.SQLITE_INTEGER then
		return tonumber(sql.sqlite3_value_int64(stmtOrValue))
	elseif kind == sql.SQLITE_FLOAT then
		return sql.sqlite3_value_double(stmtOrValue)
	elseif kind == sql.SQLITE_TEXT then
		local size = sql.sqlite3_value_bytes(stmtOrValue)
		return ffi.string(sql.sqlite3_value_text(stmtOrValue), size)
	elseif kind == sql.SQLITE_BLOB then
		local size = sql.sqlite3_value_bytes(stmtOrValue)
		return ffi.string(sql.sqlite3_value_blob(stmtOrValue), size)
	end
	return nil, sqlError("constraint", "unexpected SQLite3 type")
end

--- Bind Lua value `v` to parameter `i` (1-based) of a prepared statement.
-- Lua numbers are bound as doubles, int64_t cdata as 64-bit integers, strings
-- as TEXT, blob() wrappers as BLOB, and nil as NULL. Text and blob data are
-- copied by SQLite (transient destructor).
-- @return the SQLite result code, or nil and an error message for an
--   unsupported Lua type
local function set_column(stmtOrValue, v, i)
	local kind = type(v)
	if kind == "nil" then
		return sql.sqlite3_bind_null(stmtOrValue, i)
	end
	if kind == "number" then
		return sql.sqlite3_bind_double(stmtOrValue, i, v)
	end
	if kind == "string" then
		return sql.sqlite3_bind_text(stmtOrValue, i, v, #v, transient)
	end
	if kind == "table" and getmetatable(v) == blob_mt then
		local bytes = v[1]
		return sql.sqlite3_bind_blob(stmtOrValue, i, bytes, #bytes, transient)
	end
	if ffi.istype(int64_ct, v) then
		return tonumber(sql.sqlite3_bind_int64(stmtOrValue, i, v))
	end
	return nil, sqlError("constraint", "unexpected Lua type")
end

--- Set the result of an SQL function call (sqlite3_context) from a Lua value.
-- Uses the same type mapping as set_column: numbers become doubles, int64_t
-- cdata 64-bit integers, strings TEXT, blob() wrappers BLOB, and nil NULL.
-- @return nothing on success (the sqlite3_result_* functions return void), or
--   nil and an error message for an unsupported Lua type
local function set_value(stmtOrValue, v)
	local t = type(v)
	if t == "number" then
		return sql.sqlite3_result_double(stmtOrValue, v)
	elseif t == "string" then
		return sql.sqlite3_result_text(stmtOrValue, v, #v, transient)
	elseif t == "table" and getmetatable(v) == blob_mt then
		v = v[1]
		return sql.sqlite3_result_blob(stmtOrValue, v, #v, transient)
	elseif t == "nil" then
		return sql.sqlite3_result_null(stmtOrValue)
	elseif ffi.istype(int64_ct, v) then
		-- sqlite3_result_int64 returns void. Wrapping it in tonumber() (as the
		-- bind variant does) raises "value expected" inside the SQL callback.
		return sql.sqlite3_result_int64(stmtOrValue, v)
	else
		return nil, sqlError("constraint", "unexpected Lua type")
	end
end

-- Connection ------------------------------------------------------------------
-- sqlite3_open_v2 flags for the `mode` argument of open():
--   ro  = read-only, rw = read-write, rwc = read-write and create if missing.
-- The commented-out shared-cache variants are kept for reference only.
local open_modes = {
	-- ro = bit.bor(sql.SQLITE_OPEN_READONLY, sql.SQLITE_OPEN_SHAREDCACHE),
	-- rw = bit.bor(sql.SQLITE_OPEN_READWRITE, sql.SQLITE_OPEN_SHAREDCACHE),
	-- rwc = bit.bor(sql.SQLITE_OPEN_READWRITE, sql.SQLITE_OPEN_CREATE, sql.SQLITE_OPEN_SHAREDCACHE)
	ro = sql.SQLITE_OPEN_READONLY,
	rw = sql.SQLITE_OPEN_READWRITE,
	rwc = bit.bor(sql.SQLITE_OPEN_READWRITE, sql.SQLITE_OPEN_CREATE)
}

--- Deinitialize the SQLite library (sqlite3_shutdown).
-- A non-OK result is reported through util.printError and not raised.
local function shutdown()
	local ret = sql.sqlite3_shutdown()
	if ret ~= sql.SQLITE_OK then
		-- `util` is only loaded lazily (see sqlError), so make sure it is
		-- available before using it here.
		util = util or require "util"
		util.printError("sqlite shutdown error: %d", tonumber(ret))
	end
end

--- Open a database connection.
-- @param str  database filename passed to sqlite3_open_v2
-- @param mode "ro", "rw" or "rwc" (default "rwc"); see open_modes
-- @param pass optional key passed to sqlite3_key
--   NOTE(review): sqlite3_key is only present in encryption-enabled builds.
-- @return connection cdata, or nil and an error message
local function open(str, mode, pass)
	sql.sqlite3_initialize()
	local flags = open_modes[mode or "rwc"]
	if not flags then
		return nil, sqlError("constraint", "argument #2 to open must be ro, rw, or rwc")
	end
	-- Shared-cache mode (sqlite3_enable_shared_cache) used to be enabled here;
	-- it is deliberately left off.
	local handle = ffi.new("sqlite3*[1]")
	-- The handle is usually set even when an error code is returned, so the
	-- connection object is always created and must always be closed.
	local rc = sql.sqlite3_open_v2(str, handle, flags, nil)
	local conn = conn_ct(handle[0], false)
	-- The registries must exist before any conn:close() call.
	connStmt[conn] = setmetatable({}, {__mode = "k"}) -- weak keys, see https://www.lua.org/pil/17.html
	connCb[conn] = {scalar = {}, step = {}, final = {}}

	-- Read the error text before closing, then free the handle.
	local function fail(failCode)
		local name, text = codeMsg(conn._ptr, failCode)
		conn:close()
		return nil, sqlError(name, text)
	end

	if rc ~= sql.SQLITE_OK then
		return fail(rc)
	end
	if pass then
		rc = sql.sqlite3_key(conn._ptr, pass, #pass)
		if rc ~= sql.SQLITE_OK then
			return fail(rc)
		end
	end
	return conn
end

--- Close the connection.
-- Finalizes every statement still registered for this connection, then calls
-- sqlite3_close. Callbacks registered via setscalar/setaggregate are freed
-- only after SQLite has released the handle, so SQLite cannot call into a
-- freed FFI callback.
-- @return true on success. Returns nil and an error message if the connection
--   was already closed or sqlite3_close failed; in the latter case the
--   connection stays open and its callbacks stay registered.
function conn_mt:close()
	local err = T_open(self)
	if err then
		return nil, err
	end
	-- Finalize all statements linked to this connection; sqlite3_close fails
	-- with SQLITE_BUSY while unfinalized statements exist.
	for stmt in pairs(connStmt[self]) do
		if not stmt._closed then
			stmt:close()
		end
	end
	local code = sql.sqlite3_close(self._ptr)
	if code ~= sql.SQLITE_OK then
		return nil, E_conn(self._ptr, code)
	end
	local cbs = connCb[self]
	for _, registry in ipairs({cbs.scalar, cbs.step, cbs.final}) do
		for _, cb in pairs(registry) do
			cb:free()
		end
	end
	-- connStmt and connCb are not weak tables, so clear the entries manually.
	connStmt[self] = nil
	connCb[self] = nil
	self._ptr = nil -- The handle is freed; never pass it to SQLite again.
	self._closed = true -- Set only if close succeeded.
	return true
end

--- Finalizer: close the connection if the user has not already done so.
function conn_mt:__gc()
	if self._closed then
		return
	end
	self:close()
end

--- Compile the first SQL statement in `stmtStr`.
-- The statement is registered with the connection so conn:close() can
-- finalize it. Its `_code` starts as the prepare result (SQLITE_OK).
-- @return statement cdata, or nil and an error message when the connection is
--   closed, compilation fails, or `stmtStr` contains no SQL (only whitespace or
--   comments, for which sqlite3_prepare_v2 yields a NULL statement)
function conn_mt:prepare(stmtStr)
	local err = T_open(self)
	if err then
		return nil, err
	end
	local aptr = ffi.new("sqlite3_stmt*[1]")
	-- On an error code aptr[0] is NULL, so nothing needs to be finalized.
	local code = sql.sqlite3_prepare_v2(self._ptr, stmtStr, #stmtStr, aptr, nil)
	if code ~= sql.SQLITE_OK then
		return nil, E_conn(self._ptr, code)
	end
	if aptr[0] == nil then
		return nil, sqlError("misuse", "sql statement is empty")
	end
	local stmt = stmt_ct(aptr[0], false, self._ptr, code)
	connStmt[self][stmt] = true
	return stmt
end

-- Connection exec, __call, rowExec --------------------------------------------
-- Split `commands` into SQL statements on ';'. Pieces that were split inside a
-- single-quoted string literal are re-joined: a piece with an odd number of
-- quote characters closes a literal opened in an earlier piece.
-- Double-quoted identifiers and SQL comments are not taken into account.
local function splitStatements(commands)
	if commands:sub(-1) == ";" then
		commands = commands:sub(1, -2)
	end
	local parts = split(commands, ";")
	for i = #parts, 2, -1 do
		if countOccurrence(parts[i], "'") % 2 == 1 then
			-- Merge into the directly preceding piece and restore the ';' that
			-- split() consumed, since it belongs to the string literal.
			parts[i - 1] = parts[i - 1] .. ";" .. parts[i]
			table.remove(parts, i)
		end
	end
	return parts
end

--- Prepare the first non-empty statement in `commands`.
-- Only the first statement is prepared and returned; later statements are
-- ignored. The caller steps and closes the returned statement.
-- @return statement on success; nil and an error message if preparing failed;
--   nil and "sql command was empty" if no statement was found; nil, 0 and an
--   error message if the connection is closed.
function conn_mt:exec(commands)
	local err = T_open(self)
	if err then
		return nil, 0, err
	end
	for _, part in ipairs(splitStatements(commands)) do
		local cmd = trim(part)
		if #cmd > 0 then
			local stmt
			stmt, err = self:prepare(cmd)
			if not stmt then
				return nil, err
			end
			-- prepare() stores the prepare result code (SQLITE_OK on success),
			-- so this branch is defensive only.
			if stmt._code == sql.SQLITE_DONE then
				stmt:close()
				return nil
			end
			return stmt
		end
	end
	return nil, "sql command was empty"
end

--- Run `command` and return the column values of its single result row.
-- @return the row's column values (nil for NULL columns), or a single nil when
--   no row was produced. Returns nil and an error message if the connection is
--   closed, preparing fails, or more than one row is produced.
function conn_mt:rowExec(command)
	local err = T_open(self)
	if err then
		return nil, err
	end
	local stmt
	stmt, err = self:prepare(command)
	if not stmt then
		return nil, err
	end
	-- Read the column count while the statement is alive. The row may contain
	-- nil (NULL) values, so #row is not a reliable length for unpack.
	local ncol = stmt:_ncol()
	local row = stmt:_step()
	-- Only probe for a second row after a first one. Stepping again after a
	-- failed step could make SQLite reset and re-run the statement.
	if row and stmt:_step() then
		stmt:close()
		return nil, sqlError("misuse", "multiple records returned, 1 expected")
	end
	stmt:close()
	if row then
		return unpack(row, 1, ncol)
	end
	return nil
end

--- Run every ';'-separated statement in `commands` and print its results.
-- For each statement that returns rows, `out` is called once with the column
-- names and then once per row, with each value converted by tostring() (NULL
-- becomes an empty string).
-- @param out printing function, default `print`
-- @return nothing on success; nil and an error message if the connection is
--   closed or a statement fails to prepare (earlier results are already output)
function conn_mt:__call(commands, out)
	local err = T_open(self)
	if err then
		return nil, err
	end
	out = out or print
	for _, part in ipairs(split(commands, ";")) do
		local cmd = trim(part)
		if #cmd > 0 then
			local stmt, perr = self:prepare(cmd)
			if not stmt then
				return nil, perr
			end
			-- resultSet() returns column-major data without column names, so
			-- read the names directly from the prepared statement.
			local header = {}
			stmt:_header(header)
			local ncol = #header
			local cols, n = stmt:resultSet()
			if cols then -- All results are output, not only the last statement's.
				out(unpack(header, 1, ncol))
				for row = 1, n do
					local line = {}
					for col = 1, ncol do
						local v = cols[col][row]
						line[col] = v == nil and "" or tostring(v) -- Empty strings for NULLs.
					end
					out(unpack(line, 1, ncol))
				end
			end
			stmt:close()
		end
	end
end

-- Callbacks -------------------------------------------------------------------
--- Replace the callback stored under `name` in one of a connection's registries.
-- The previously stored callback, if any, is freed. `f` may be nil, which
-- simply removes the entry.
-- @param where "scalar", "step" or "final"
local function updateCb(self, where, name, f)
	local registry = connCb[self][where]
	local previous = registry[name]
	if previous then
		previous:free()
	end
	registry[name] = f
end

--- Wrap Lua function `f` as an SQLite scalar-function callback.
-- Arguments are converted with get_value, `f` runs under pcall, and its result
-- is set with set_value. An error in `f` is reported through
-- sqlite3_result_error and never raised across the C boundary.
-- @return an FFI callback (sqlite3_cbstep, declared in sqlite3-ffi.lua). The
--   caller must eventually :free() it.
local function scalarCb(name, f)
	local values = {} -- Conversion buffer, reused across calls.
	local function sqlf(context, nvalues, pvalues)
		-- pvalues is indexed 0..N-1.
		for i = 1, nvalues do
			values[i] = get_value(pvalues[i - 1])
		end
		local ok, result = pcall(f, unpack(values, 1, nvalues))
		if not ok then
			-- tostring(): the error value need not be a string, and a failed
			-- concatenation here would raise inside the C callback.
			local msg = "Lua registered scalar function " .. name .. " error: " .. tostring(result)
			sql.sqlite3_result_error(context, msg, #msg)
		else
			set_value(context, result)
		end
	end
	return ffi.cast("sqlite3_cbstep", sqlf) -- defined in sqlite3-ffi.lua
end

--- Return the Lua state table for the current aggregate invocation.
-- The pointer from sqlite3_aggregate_context only serves as a unique key
-- (its integer address); all state data lives on the Lua side in
-- aggregateState. The state is created with initstate() on first use.
-- @param size bytes requested from sqlite3_aggregate_context (contents unused)
-- @return state, key under which it is stored in aggregateState
local function getstate(context, initstate, size)
	local key = tonumber(ffi.cast("intptr_t", sql.sqlite3_aggregate_context(context, size)))
	local state = aggregateState[key]
	if state == nil then
		state = initstate()
		aggregateState[key] = state
	end
	return state, key
end

--- Wrap Lua function `f(state, ...)` as the step callback of an aggregate.
-- The per-invocation state comes from getstate (created with initstate()).
-- Errors from initstate or `f` are reported through sqlite3_result_error and
-- never raised across the C boundary.
-- @return an FFI callback (sqlite3_cbstep, declared in sqlite3-ffi.lua). The
--   caller must eventually :free() it.
local function stepcb(name, f, initstate)
	local values = {} -- Conversion buffer, reused across calls.
	local function sqlf(context, nvalues, pvalues)
		-- pvalues is indexed 0..N-1.
		for i = 1, nvalues do
			values[i] = get_value(pvalues[i - 1])
		end
		-- getstate calls the user's initstate(), which may raise.
		local ok, result = pcall(getstate, context, initstate, 1)
		if ok then
			ok, result = pcall(f, result, unpack(values, 1, nvalues))
		end
		if not ok then
			local msg = "Lua registered step function " .. name .. " error: " .. tostring(result)
			sql.sqlite3_result_error(context, msg, #msg)
		end
	end
	return ffi.cast("sqlite3_cbstep", sqlf) -- defined in sqlite3-ffi.lua
end

--- Wrap Lua function `f(state)` as the final callback of an aggregate.
-- The state is removed from aggregateState and `f`'s return value becomes the
-- aggregate result via set_value. Errors from initstate or `f` are reported
-- through sqlite3_result_error and never raised across the C boundary.
-- @return an FFI callback (sqlite3_cbfinal, declared in sqlite3-ffi.lua). The
--   caller must eventually :free() it.
local function finalcb(name, f, initstate)
	local function sqlf(context)
		-- getstate calls the user's initstate() (e.g. when no row was aggregated).
		local ok, result, pid = pcall(getstate, context, initstate, 0)
		if ok then
			aggregateState[pid] = nil -- Clear the state.
			ok, result = pcall(f, result)
		end
		if not ok then
			local msg = "Lua registered final function " .. name .. " error: " .. tostring(result)
			sql.sqlite3_result_error(context, msg, #msg)
		else
			set_value(context, result)
		end
	end
	return ffi.cast("sqlite3_cbfinal", sqlf) -- defined in sqlite3-ffi.lua
end

--- Register (or, with f == nil, remove) Lua function `f` as SQL scalar function `name`.
-- The function accepts any number of arguments (-1) and any text encoding
-- (5 = SQLITE_ANY).
-- @return true on success. Returns nil and an error message if the connection
--   is closed or sqlite3_create_function fails; in that case the previously
--   registered callback is kept and the new one is freed.
function conn_mt:setscalar(name, f)
	local err = T_open(self)
	if err then
		return nil, err
	end
	-- Necessary to avoid bad calloc in some use cases (presumably because
	-- sqlite3_step calls back into Lua from JIT-compiled code).
	jit.off(stmt_step)
	local cbf = f and scalarCb(name, f) or nil
	-- If cbf is nil this removes the function in SQLite.
	local code = sql.sqlite3_create_function(self._ptr, name, -1, 5, nil, cbf, nil, nil)
	if code ~= sql.SQLITE_OK then
		err = E_conn(self._ptr, code)
		-- SQLite did not take the new callback, and the old one is still
		-- installed: free the new one, keep the old one alive.
		if cbf then
			cbf:free()
		end
		return nil, err
	end
	updateCb(self, "scalar", name, cbf) -- Update and free the old callback.
	return true
end

--- Register (or remove) an SQL aggregate function `name`.
-- @param initstate function returning a fresh state for each aggregate invocation
-- @param step      f(state, ...) called for every row
-- @param final     f(state) whose return value is the aggregate result
-- @return true on success. Returns nil and an error message if the connection
--   is closed or sqlite3_create_function fails; in that case the previously
--   registered callbacks are kept and the new ones are freed.
function conn_mt:setaggregate(name, initstate, step, final)
	local err = T_open(self)
	if err then
		return nil, err
	end
	-- Necessary to avoid bad calloc in some use cases (presumably because
	-- sqlite3_step calls back into Lua from JIT-compiled code).
	jit.off(stmt_step)
	local cbs = step and stepcb(name, step, initstate) or nil
	local cbf = final and finalcb(name, final, initstate) or nil
	-- -1: any number of arguments; 5 = SQLITE_ANY text encoding. If cbs and
	-- cbf are nil this removes the function in SQLite.
	local code = sql.sqlite3_create_function(self._ptr, name, -1, 5, nil, nil, cbs, cbf)
	if code ~= sql.SQLITE_OK then
		err = E_conn(self._ptr, code)
		if cbs then
			cbs:free()
		end
		if cbf then
			cbf:free()
		end
		return nil, err
	end
	updateCb(self, "step", name, cbs) -- Update and free the old callback.
	updateCb(self, "final", name, cbf) -- Update and free the old callback.
	return true
end

-- Connection ctype: _ptr is the sqlite3* handle; _closed is set by conn_mt:close().
conn_ct = ffi.metatype("struct { sqlite3* _ptr; bool _closed; }", conn_mt)

-- Statement -------------------------------------------------------------------
--- Reset the statement so it can be stepped again from the start.
-- Bindings are kept (use clearbind to drop them).
-- @return self, or nil and an error message if the statement is closed
--   (resetting a finalized statement would touch freed memory)
function stmt_mt:reset()
	local err = T_open(self)
	if err then
		return nil, err
	end
	-- Ignore a possible error code: it would repeat the error from the most
	-- recent evaluation of the statement, which was already reported.
	sql.sqlite3_reset(self._ptr)
	self._code = sql.SQLITE_OK -- Always succeeds.
	return self
end

--- Finalize the statement.
-- A possible error code from sqlite3_finalize is ignored: it would repeat the
-- error already reported by the most recent evaluation of the statement.
-- @return true, or nil and an error message if the statement was already
--   closed (finalizing twice would free the same handle twice)
function stmt_mt:close()
	local err = T_open(self)
	if err then
		return nil, err
	end
	sql.sqlite3_finalize(self._ptr)
	-- Drop the freed handle so any accidental later use passes NULL instead
	-- of a dangling pointer.
	self._ptr = nil
	self._code = sql.SQLITE_OK
	self._closed = true -- Finalize must happen exactly once.
	return true
end

--- Finalizer: finalize the statement if the user has not already closed it.
function stmt_mt:__gc()
	if self._closed then
		return
	end
	self:close()
end

-- Statement step, resultSet ---------------------------------------------------
--- Number of result columns, as reported by sqlite3_column_count.
function stmt_mt:_ncol()
	local count = sql.sqlite3_column_count(self._ptr)
	return count
end

--- Fill `h[1..ncol]` with the result column names.
-- SQLite numbers columns 0..N-1; `h` is 1-based.
function stmt_mt:_header(h)
	local ptr = self._ptr
	for col = 0, self:_ncol() - 1 do
		h[col + 1] = ffi.string(sql.sqlite3_column_name(ptr, col))
	end
end

--- Advance the statement by one row (no closed check; see stmt_mt:step).
-- @param row    optional table to fill, reused to avoid allocations
-- @param header optional table filled with the column names on each row
-- @return row, header when a row is available; nil when the statement has
--   finished; nil and an error message when sqlite3_step fails
stmt_step = function(self, row, header)
	-- Must not step again once DONE: sqlite3_step would give undefined results.
	if self._code == sql.SQLITE_DONE then
		return nil
	end
	local code = sql.sqlite3_step(self._ptr)
	self._code = code
	if code == sql.SQLITE_ROW then
		-- None of the sql.* calls below can fail here.
		row = row or {}
		local ptr = self._ptr
		for i = 1, self:_ncol() do
			row[i] = get_column(ptr, i - 1)
		end
		if header then
			self:_header(header)
		end
		return row, header
	elseif code == sql.SQLITE_DONE then -- Have finished now.
		return nil
	end
	-- Any code other than ROW or DONE is an error: report it to the caller.
	return nil, E_conn(self._conn, code)
end
stmt_mt._step = stmt_step

--- Advance the statement by one row; see stmt_step for the return values.
-- @return additionally nil and an error message if the statement is closed
--   (stepping a finalized statement would touch freed memory)
function stmt_mt:step(row, header)
	local err = T_open(self)
	if err then
		return nil, err
	end
	return self:_step(row, header)
end

--- Collect up to `maxrecords` remaining rows, column-major.
-- @param maxrecords maximum number of rows to read (default: all)
-- @return cols, n where cols[c][r] is the value of column c (1-based) in row r
--   and n is the number of rows read. NULL values leave nil holes, so use n
--   (not #cols[c]) as the row count. Returns nil, 0 when there are no rows;
--   nil, 0, err on a closed statement or an invalid `maxrecords`.
function stmt_mt:resultSet(maxrecords)
	local err = T_open(self)
	if err then
		return nil, 0, err
	end
	maxrecords = maxrecords or math.huge
	if maxrecords < 1 then
		return nil, 0, sqlError("constraint", "argument #1 to resultSet must be >= 1")
	end
	if self._code == sql.SQLITE_DONE then
		return nil, 0
	end
	local row, header = self:_step({}, {})
	if not row then -- No records case.
		return nil, 0
	end
	local ncol = #header
	-- First record.
	local cols = {}
	for c = 1, ncol do
		cols[c] = {row[c]}
	end
	-- Other records; `row` is reused as the step buffer.
	local n = 1
	while n < maxrecords and self:_step(row) do
		n = n + 1
		for c = 1, ncol do
			cols[c][n] = row[c]
		end
	end
	return cols, n
end

-- Statement bind --------------------------------------------------------------
--- Bind `v` to parameter `i` (1-based) without a closed check.
-- @return self on success; nil and an error message if `v` has an unsupported
--   Lua type or SQLite rejects the binding (e.g. index out of range)
function stmt_mt:_bind1(i, v)
	local code, err = set_column(self._ptr, v, i) -- Here indexing 1,N.
	if code == nil then
		-- set_column returns nil, err for unsupported Lua types; there is no
		-- SQLite code to look up.
		return nil, err
	end
	if code ~= sql.SQLITE_OK then
		return nil, E_conn(self._conn, code)
	end
	return self
end

--- Bind `v` to parameter `i` (1-based).
-- @return the result of _bind1, or nil and an error message if the statement is closed
function stmt_mt:bind1(i, v)
	local err = T_open(self)
	if err then
		return nil, err
	end
	return self:_bind1(i, v)
end

--- Bind all arguments to parameters 1..select("#", ...) (trailing nils bind NULL).
-- @return self, or nil and an error message if the statement is closed or a
--   binding fails (binding stops at the first failure)
function stmt_mt:bind(...)
	local err = T_open(self)
	if err then
		return nil, err
	end
	for i = 1, select("#", ...) do
		local ok, berr = self:_bind1(i, (select(i, ...)))
		if not ok then
			return nil, berr
		end
	end
	return self
end

--- Reset all parameter bindings to NULL.
-- @return self, or nil and an error message if the statement is closed
function stmt_mt:clearbind()
	local err = T_open(self)
	if err then
		return nil, err
	end
	local code = sql.sqlite3_clear_bindings(self._ptr)
	T_okCode(self._conn, code)
	return self
end

-- Statement ctype:
--   _ptr    prepared statement handle
--   _closed set once the statement has been finalized by close()
--   _conn   raw sqlite3* of the owning connection (used for error messages)
--   _code   last result code stored by prepare/step/reset
stmt_ct = ffi.metatype([[struct {
  sqlite3_stmt* _ptr;
  bool          _closed;
  sqlite3*      _conn;
  int32_t       _code;
}]], stmt_mt)

-- Public module API. Connection and statement methods are reached through the
-- objects returned by open() and conn:prepare().
return {sql = sql, get_column = get_column, E_conn = E_conn, open = open, blob = blob, shutdown = shutdown}

--[[ -- code to generate setters/getters:
-- Getters / Setters to minimize code duplication ------------------------------
local sql_get_code = [=[
return function(stmtOrValue <opt_i>)
  local t = sql.sqlite3_<variant>_type(stmtOrValue <opt_i>)
  if t == sql.SQLITE_INTEGER then
    return tonumber(sql.sqlite3_<variant>_int64(stmtOrValue <opt_i>))
  elseif t == sql.SQLITE_FLOAT then
    return sql.sqlite3_<variant>_double(stmtOrValue <opt_i>)
  elseif t == sql.SQLITE_TEXT then
    local nb = sql.sqlite3_<variant>_bytes(stmtOrValue <opt_i>)
    return ffi.string(sql.sqlite3_<variant>_text(stmtOrValue <opt_i>), nb)
  elseif t == sql.SQLITE_BLOB then
    local nb = sql.sqlite3_<variant>_bytes(stmtOrValue <opt_i>)
    return ffi.string(sql.sqlite3_<variant>_blob(stmtOrValue <opt_i>), nb)
  elseif t == sql.SQLITE_NULL then
    return nil
  else
    return nil, sqlError("constraint", "unexpected SQLite3 type")
  end
end
]=]

local sql_set_code = [=[
return function(stmtOrValue, v <opt_i>)
  local t = type(v)
  if ffi.istype(int64_ct, v) then
    return tonumber(sql.sqlite3_<variant>_int64(stmtOrValue <opt_i>, v))
  elseif t == "number" then
    return sql.sqlite3_<variant>_double(stmtOrValue <opt_i>, v)
  elseif t == "string" then
    return sql.sqlite3_<variant>_text(stmtOrValue <opt_i>, v, #v,
      transient)
  elseif t == "table" and getmetatable(v) == blob_mt then
    v = v[1]
    return sql.sqlite3_<variant>_blob(stmtOrValue <opt_i>, v, #v,
      transient)
  elseif t == "nil" then
    return sql.sqlite3_<variant>_null(stmtOrValue <opt_i>)
  else
    return nil, sqlError("constraint", "unexpected Lua type")
  end
end
]=]

-- Environment for setters/getters.
local sql_env = {sql = sql, transient = transient, ffi = ffi, int64_ct = int64_ct, blob_mt = blob_mt, getmetatable = getmetatable, tonumber = tonumber, sqlError = sqlError, type = type}

local function sql_format(s, variant, index)
	return s:gsub("<variant>", variant):gsub("<opt_i>", index)
end

local function loadcode(s, env)
	local ret = assert(loadstring(s))
	if env then
		setfenv(ret, env)
	end
	return ret()
end

-- Must always be called from *:_* function due to error level 4.
local get_column = loadcode(sql_format(sql_get_code, "column", ",i"), sql_env)
local get_value = loadcode(sql_format(sql_get_code, "value", "  "), sql_env)
local set_column = loadcode(sql_format(sql_set_code, "bind", ",i"), sql_env)
local set_value = loadcode(sql_format(sql_set_code, "result", "  "), sql_env)
]]
