-- lib/db/call-rest.lua
local util = require "util"
local json = require "json"
local dconn = require "dconn"
local dprf = require "dprf"
local socket = require "system/socket"
local peg = require "peg"
local dt = require "dt"
local coro = require "coro"
local rest4d = require "db/database-rest4d"
local l = require"lang".l
local base64 = require "base64" -- base64lua for correct answer
local print = util.print

-- Module table; public functions are attached to it below and it is returned at the end of the file.
local callRest = {}

-- When coro.useCoro() is true the receive loops below skip util.sleep() between retries.
local useCoro = coro.useCoro()
-- Module scratch state: the last connect preference (used in error messages) and the
-- default number of answer characters printed in debug output.
local loc = {connectPrf = {}, result_debug_length = 3000}

local sleepMillisecond = 10 -- about 1 tick; pause between non-blocking receive retries (non-coroutine mode)
-- Set by restConnect() from the chosen connect preference's "disconnect" flag;
-- when true, callRest.call() closes the socket after each call.
local doDisconnect
local defaultCallMethod = "POST"
local defaultCallPath = "/rest/nc/"
local defaultBasicAuthorization = nil -- e.g. "<base64 of user:password>"; never commit real credentials here
-- Request-line and header template. call() fills the {{...}} placeholders, drops the
-- {{basic_authorization}} line when no authorization is configured, and converts "\n" to "\r\n".
local httpStart = [[
{{call_method}} {{call_path}}{{call}} HTTP/1.1
Host: {{host}}
{{basic_authorization}}
Connection: keep-alive]]
-- Header notes: User-Agent is not sent; "Connection: keep-alive" is already in the template above.
-- User-Agent: nc-rest-call
-- Connection: keep-alive

--- Print a red socket-error line prefixed with the current connection info.
-- @param txt error text
-- @param details optional extra text; only its first 400 characters are appended on a new line
-- @return the composed message (without the " *** socket error: " prefix)
local function printError(txt, details)
	local prefix = "connection: " .. dconn.info() .. ", "
	local tail = ""
	if details then
		tail = "\n" .. details:sub(1, 400)
	end
	local message = prefix .. txt .. tail
	util.printRed("\n *** socket error: " .. message)
	return message
end

--- Close and forget the connection's call socket, if it has one.
-- Prints the "closed" notice only when the connect counter is unset or 1,
-- so repeated reconnect/disconnect cycles stay quiet.
-- @param conn connection table (uses conn.driver_conn.socket and conn.rest_connect_count)
local function restDisconnect(conn)
	local driver = conn.driver_conn
	local openSocket = driver and driver.socket
	if openSocket then
		openSocket:close()
		driver.socket = nil
	end
	local connectCount = conn.rest_connect_count
	if connectCount == nil or connectCount == 1 then
		print(l "call worker has been closed")
	end
end
-- Public entry point for closing the connection's call socket (same function as restDisconnect).
callRest.disconnect = restDisconnect

--- Receive timeout in seconds for conn.
-- Uses the first truthy value of: rest4d.timeout(conn), conn.timeout, then 30.
local function timeout(conn)
	if rest4d then
		local seconds = rest4d.timeout(conn)
		if seconds then
			return seconds
		end
	end
	if conn and conn.timeout then
		return conn.timeout
	end
	return 30
end

--- Read the HTTP response header from the connection's call socket.
-- Polls receiveSocket:receive("\r\n\r\n") until the header arrives, the socket reports a
-- hard error, or timeout(conn) seconds have elapsed (the socket is then closed).
-- On success sets result.header, result.data = "", result.chunked (boolean) and
-- result.contentLength (math.huge for chunked transfer encoding).
-- Header names are matched case-insensitively, as HTTP requires.
-- @param result table whose result.conn.driver_conn.socket is the connected socket
-- @return nil on success, otherwise an error message string
local function readHeader(result)
	local conn = result.conn
	local receiveSocket = conn.driver_conn and conn.driver_conn.socket
	local startTime
	local errTxt
	local recv, err
	while true do
		recv, err = receiveSocket:receive("\r\n\r\n")
		if recv then
			break
		end
		-- NOTE(review): numeric codes >= -1 appear to mean "no data yet, retry"; anything
		-- else (lower codes, string errors, nil) is treated as a hard error ending the read.
		if type(err) ~= "number" or err < -1 then
			break
		end
		if not startTime then
			startTime = util.seconds()
		end
		if not useCoro then
			util.sleep(sleepMillisecond)
		end
		if util.seconds(startTime) >= timeout(conn) then
			err = -2
			restDisconnect(conn)
			errTxt = l("call worker read header socket receive timeout > %d seconds, connection was closed", timeout(conn))
			break
		end
	end
	if not recv then
		return errTxt or tostring(err)
	end

	result.header = recv:sub(1, -4) -- we did wait for "\r\n\r\n" so full header is read
	result.data = ""
	-- Field names are case-insensitive; anchoring on "\n" avoids matching e.g. "X-Original-Content-Length".
	local lowerHeader = recv:lower()
	local lengthText = lowerHeader:match("\ncontent%-length:%s*(%d+)")
	if lengthText then
		result.contentLength = tonumber(lengthText)
	end
	result.chunked = lowerHeader:find("\ntransfer%-encoding:[^\r\n]*chunked") ~= nil
	if not (result.chunked or result.contentLength) then
		return printError("call worker read header error: result has no content-length and is not transfer-encoding: chunked")
	end
	if result.chunked then
		-- real length is unknown until the terminating zero-size chunk is seen in readContent()
		result.contentLength = math.huge
	end
	return nil
end

--- True when a raw chunked body ends with the terminating zero-size chunk
-- (including a body that consists of nothing but the terminator).
local function isChunkedBodyComplete(data)
	return data == "0\r\n\r\n" or data:sub(-7) == "\r\n0\r\n\r\n"
end

--- Decode a raw HTTP/1.1 chunked body ("<hex size>[;ext]\r\n<data>\r\n" ... "0\r\n\r\n").
-- Walks the chunk sizes instead of splitting on "\r\n", so chunk data that itself
-- contains "\r\n" (e.g. pretty-printed JSON) is decoded correctly. Trailers are ignored.
local function decodeChunkedBody(data)
	local parts = {}
	local pos = 1
	local chunkNumber = 0
	while true do
		local lineEnd = data:find("\r\n", pos, true)
		if not lineEnd then
			break -- truncated body
		end
		-- the size line may carry ";extension" parameters after the hex digits
		local size = tonumber(data:sub(pos, lineEnd - 1):match("^%s*(%x+)") or "", 16)
		if not size or size == 0 then
			break -- last chunk (or malformed size line)
		end
		chunkNumber = chunkNumber + 1
		local chunkStart = lineEnd + 2
		local chunk = data:sub(chunkStart, chunkStart + size - 1)
		if #chunk ~= size then
			printError(l("rest call result data chunk %d length %d number is not same as chunk data length %d", chunkNumber, size, #chunk))
		end
		parts[#parts + 1] = chunk
		pos = chunkStart + size + 2 -- skip the chunk data and its trailing "\r\n"
	end
	return table.concat(parts)
end

--- Read the response body after readHeader() has filled result.contentLength / result.chunked.
-- Polls the socket until result.contentLength bytes are present (or, for chunked encoding,
-- until the terminating zero-size chunk arrives), a receive error occurs, or timeout(conn)
-- elapses (the socket is then closed). Bytes beyond contentLength go to result.nextdata;
-- chunked bodies are decoded in place.
-- @param result table with result.conn, result.data, result.contentLength, result.chunked
-- @return nil on success, otherwise an error message string
local function readContent(result)
	if result.data == nil then
		result.data = ""
	end
	if result.chunked and isChunkedBodyComplete(result.data) then
		result.contentLength = #result.data
	end
	local errTxt
	if result.contentLength and #result.data < result.contentLength then
		local conn = result.conn
		local receiveSocket = conn.driver_conn and conn.driver_conn.socket
		local sleepCount = 0
		local startTime
		while #result.data < result.contentLength do
			local recv, err, dataLength
			if result.chunked then -- no content-length header, read until the last chunk is 0
				recv, err, dataLength = receiveSocket:receive(0) -- 0 means: receive as much as you can get
			else
				recv, err, dataLength = receiveSocket:receive(result.contentLength - #result.data)
			end
			if recv then
				result.data = result.data .. recv
				if result.chunked and isChunkedBodyComplete(result.data) then
					result.contentLength = #result.data
				end
			end
			if err then
				errTxt = printError(l("call worker read content socket receive error: %s", tostring(err)))
				break
			end
			if not recv then
				-- NOTE(review): dataLength >= -1 appears to mean "no data yet, retry"
				if type(dataLength) ~= "number" or dataLength < -1 then
					errTxt = printError(l("call worker read content socket receive error: %s", tostring(dataLength)))
					break
				end
				if not startTime then
					startTime = util.seconds()
				end
				if not useCoro then
					util.sleep(sleepMillisecond)
				end
				sleepCount = sleepCount + 1
				if sleepCount == 1 or sleepCount % 10 == 0 then
					util.print(" rest call socket receive read loop %d, connection: %s", sleepCount, result.conn and result.conn.info or "unknown")
				end
				if util.seconds(startTime) >= timeout(conn) then
					restDisconnect(conn)
					errTxt = l "call worker read content socket receive timeout, connection was closed"
					print(errTxt)
					break
				end
			end
		end
	end
	if result.contentLength and #result.data > result.contentLength then
		result.nextdata = result.data:sub(result.contentLength + 1) -- TODO: use saved nextdata, rename to nextData
		result.data = result.data:sub(1, result.contentLength)
	end
	if result.chunked then
		result.data = decodeChunkedBody(result.data)
	end
	return errTxt
end

--- Send a fully built HTTP request on conn's call socket and read the answer into result.
-- @param callStr raw request text (request line, headers and optional body)
-- @param result result table; readHeader()/readContent() fill header, data, etc.
-- @param conn connection table (uses conn.driver_conn.socket and conn.debug)
-- @return nil on success, otherwise an error message (send, header or content error)
local function sendCall(callStr, result, conn)
	if type(callStr) ~= "string" then
		return printError(l("call worker error, call string type '%s' is not a string", type(callStr)))
	elseif #callStr == 0 then
		return printError("call worker error, call string length is zero")
	end
	local sendSocket = conn.driver_conn and conn.driver_conn.socket
	if sendSocket == nil then
		return printError("call worker error, send socket is not connected")
	end
	if conn.debug then
		-- never print credentials: mask the Authorization header value
		local shown = callStr:gsub("(\r\nAuthorization:)[^\r\n]*", "%1 XXXXXX")
		util.print("sending rest call: '%s'", shown)
	end
	local bytesSent = sendSocket:send(callStr)
	if bytesSent ~= #callStr then
		return printError(l("call worker error '%s' in socket send, bytes to send: %d", tostring(bytesSent), #callStr))
	end
	local err = readHeader(result)
	if err ~= nil then
		-- previously this error was dropped, hiding header timeouts and socket errors
		return err
	end
	return readContent(result)
end

--- Build the raw HTTP/1.1 request described by paramTbl and send it via sendCall().
-- paramTbl fields used: call_url (with {{key}} placeholders filled from query_parameter;
-- table values become "a,b" for non-strings or "'a','b'" for strings), host.host,
-- call_method, call_path, bearer_authorization / basic_authorization, content
-- (a table is JSON-encoded), header["content-type"] and debug.
-- @param query query preference (query.name is shown in debug output)
-- @param result result table with result.conn
-- @param paramTbl request parameter table
-- @return nil on success, otherwise an error message
local function call(query, result, paramTbl)
	local conn = result.conn
	local callFunction = paramTbl.call_url or ""
	if paramTbl.query_parameter and peg.found(callFunction, "{{") then
		for key, val in pairs(paramTbl.query_parameter) do
			local placeholder = "{{" .. tostring(key) .. "}}"
			if peg.found(callFunction, placeholder) then
				local text = val
				if type(val) == "table" then
					if val[1] and type(val[1]) ~= "string" then
						text = table.concat(val, ",")
					else
						text = "'" .. table.concat(val, "','") .. "'"
					end
				end
				-- products?per_page={{per_page}}&offset={{offset}} -> products?per_page=100&offset=1
				callFunction = peg.replace(callFunction, placeholder, tostring(text))
			end
		end
	end
	local callStr = peg.replace(httpStart, "{{host}}", peg.parseAfter(paramTbl.host.host, "://"))
	callStr = peg.replace(callStr, "{{call_method}}", paramTbl.call_method or defaultCallMethod)
	callStr = peg.replace(callStr, "{{call_path}}", paramTbl.call_path or defaultCallPath)
	callStr = peg.replace(callStr, "{{call}}", callFunction)
	-- authString is also used to mask the credential in debug output, for bearer and basic alike
	local authString = paramTbl.bearer_authorization or paramTbl.basic_authorization or defaultBasicAuthorization
	if authString then
		callStr = peg.replace(callStr, "{{basic_authorization}}", "Authorization: " .. authString)
	else
		callStr = peg.replace(callStr, "{{basic_authorization}}\n", "")
	end
	local content = paramTbl.content
	if type(content) == "table" then
		content = json.toJsonRaw(content)
	elseif type(content) == "number" then
		content = tostring(content) -- '#' is not defined for numbers
	end
	callStr = peg.replace(callStr, "\n", "\r\n")
	if content and #content > 0 then
		if paramTbl.header and paramTbl.header["content-type"] then
			callStr = callStr .. "\r\nContent-Type: " .. paramTbl.header["content-type"]
		else
			callStr = callStr .. "\r\nContent-Type: application/json"
		end
		-- Content-Length counts bytes, which is what '#' returns for Lua strings
		callStr = callStr .. "\r\nContent-Length: " .. tostring(#content) .. "\r\n\r\n" .. content
	else
		callStr = callStr .. "\r\n\r\n"
	end
	if paramTbl.debug then
		if authString then
			util.printInfo("\n* rest call '%s':\n'%s'\n", query.name, peg.replace(callStr, authString, "XXXXXX"))
		else
			util.printInfo("\n* rest call '%s':\n'%s'\n", query.name, callStr)
		end
	end
	return sendCall(callStr, result, conn)
end

--- Build the request callback that callAuth() hands to conn.code_function.authenticate.
-- The callback takes {header = ..., post_data = ...}, sends it to conn.host with the
-- connection's call_method/call_path via call(), and returns the shared result table
-- together with call()'s error value.
local function sendCallFunction(query, result, conn)
	local function sendRequest(callParam)
		local requestParam = {
			header = callParam.header,
			content = callParam.post_data,
			host = {host = conn.host},
			call_method = conn.call_method,
			call_path = conn.call_path,
		}
		local callErr = call(query, result, requestParam)
		return result, callErr
	end
	return sendRequest
end

--- Open the call socket using the first connect preference that succeeds.
-- connectPrf may be a single preference {host, port, connect_timeout, disconnect} or an
-- array of them. The module-level doDisconnect flag is set from the chosen preference.
-- When conn is given, conn.rest_connect_count is incremented on every successful connect;
-- the success message is printed only for the first one.
-- @return driver connection table {__name = "call-rest", socket = socket-or-nil}
local function restConnect(connectPrf, conn)
	loc.connectPrf = connectPrf -- save for error messages
	doDisconnect = false
	if connectPrf == nil then
		return {__name = "call-rest", socket = nil}
	end
	if type(connectPrf[1]) ~= "table" then
		return restConnect({connectPrf}, conn) -- recursive call, connectPrf needs to be an array
	end
	local connectSocket
	for _, rec in ipairs(connectPrf) do
		connectSocket = socket.connect(rec.host, rec.port, "tcp", rec.connect_timeout)
		if connectSocket then
			doDisconnect = rec.disconnect
			if conn then
				conn.rest_connect_count = (conn.rest_connect_count or 0) + 1
			end
			if conn == nil or conn.rest_connect_count == 1 then
				util.printOk("call worker successfully connected to %s%s, socket '%s'", rec.host, rec.port and ":" .. rec.port or "", tostring(connectSocket.socket))
			end
			-- stop at the first success: trying further preferences would open (and leak)
			-- more sockets and could overwrite this one with nil
			break
		end
	end
	return {__name = "call-rest", socket = connectSocket}
end
callRest.connect = restConnect

--- Authenticate conn before a REST call, or again after an "expired token" answer.
-- Two modes:
--  * conn.code_method.authenticate (custom code, not yet authenticated): run it with a
--    request callback from sendCallFunction(); on success conn.authenticated is set to
--    dt.currentUtcString().
--  * otherwise: call conn.auth.call_url (placeholders filled from conn.auth_preference),
--    copy result.data[conn.auth.auth_tag] into conn.query_parameter and substitute it
--    into restParam.call_url's {{auth_tag}} placeholder.
-- @param query query preference
-- @param restParam request parameters of the pending call; updated in place
-- @param conn connection table
-- @return nil on success, otherwise an error value
local function callAuth(query, restParam, conn)
	if not conn.authenticated and conn.code_method and conn.code_method.authenticate then
		if conn.code_function == nil then
			-- restParameter() can reach this before authCall() has resolved the code functions
			local execute = require "execute"
			conn.code_function = {}
			for key, val in pairs(conn.code_method) do
				conn.code_function[key] = execute.getFunction(val)
			end
		end
		local result = {conn = conn}
		local err = conn.code_function.authenticate(conn.auth_preference, sendCallFunction(query, result, conn))
		if err == nil then
			conn.authenticated = dt.currentUtcString()
		end
		return err
	end
	local restAuthParam = util.clone(restParam)
	if conn.auth_preference and conn.auth and conn.auth.call_url then
		local func = conn.auth.call_url
		if type(conn.auth_preference) == "string" then
			conn.auth_preference = dprf.prf(conn.auth_preference)
		end
		if peg.found(func, "{{") then
			for key, val in pairs(conn.auth_preference) do
				func = peg.replace(func, "{{" .. tostring(key) .. "}}", tostring(val)) -- func is always a string
			end
		end
		restAuthParam.call_url = func
	end
	local callBody = conn.auth and conn.auth.call_body or conn.call_body
	if callBody then
		restAuthParam.content = util.clone(callBody)
	end
	-- callRest.call() returns ONE value: a result table, or an "error: ..." string when it
	-- cannot connect. The error must be taken from there; a second return value never exists.
	local result = callRest.call(query, restAuthParam, conn)
	local err
	if type(result) == "table" then
		err = result.error
	else
		err = result
		result = {}
	end
	if conn.auth and conn.auth.disconnect and not conn.disconnect then
		restDisconnect(conn)
	end
	local authTag = conn.auth and conn.auth.auth_tag
	if authTag ~= nil and type(result.data) == "table" and result.data[authTag] then
		conn.query_parameter = conn.query_parameter or {}
		conn.query_parameter[authTag] = result.data[authTag]
	end
	util.recToRec(restParam.query_parameter, conn.query_parameter) -- copies access_token or other similar to query_parameter
	if authTag ~= nil and type(restParam.call_url) == "string" and conn.query_parameter and conn.query_parameter[authTag] ~= nil then
		restParam.call_url = peg.replace(restParam.call_url, "{{" .. tostring(authTag) .. "}}", tostring(conn.query_parameter[authTag]))
	end
	return err
	--[[ earlier explicit auth parameter layout, kept for reference:
		{
		disconnect = true,
		call_method = conn.auth.call_method or conn.call_method,
		call_path = conn.auth.call_path or conn.call_path,
		call_url = conn.auth.call_url or conn.call_url,
		host = {host = conn.host, port = conn.port, connect_timeout = conn.connect_timeout},
		basic_authorization = conn.user and conn.password and base64.encode(conn.user..":"..conn.password),
		debug = conn.auth.parameter and conn.auth.parameter.show_sql or conn.debug or false,
		-- query_parameter = conn.auth.query_parameter or conn.query_parameter
	} --]]
end

--- True when value is neither nil/false nor the empty string.
local function isFilled(value)
	return not not value and value ~= ""
end

--- Build the request parameter table for query preference prf on connection conn.
-- The call URL is prf[prf.parameter.call_tag], prf.call_url or conn.call_url. Authorization
-- comes from conn.auth_preference: "Bearer <authorization>" when conn.bearer_authorization;
-- with conn.basic_authorization either "Basic base64(user:password)" or a ready-made
-- authorization value (the latter wins). The body merges conn.call_body and prf.call_body;
-- body values that name a key of prf.parameter are replaced by that parameter's value.
-- When conn.auth is configured and conn.query_parameter[auth_tag] is "", callAuth() runs first.
-- @return request parameter table, or {error = err} when authentication failed
function callRest.restParameter(prf, conn)
	local callFunction = prf.parameter and prf.parameter.call_tag and prf[prf.parameter.call_tag] or prf.call_url
	local restParam = {
		-- disconnect = false,
		call_method = prf.call_method or conn.call_method,
		call_path = prf.call_path or conn.call_path,
		call_url = callFunction or conn.call_url,
		host = {host = conn.host, port = conn.port, connect_timeout = conn.connect_timeout},
		debug = prf.parameter and prf.parameter.show_sql or conn.debug or false,
		query_parameter = prf.parameter,
	}
	if conn.disconnect ~= nil then
		restParam.disconnect = conn.disconnect
	end
	local pref = conn.auth_preference
	if conn.bearer_authorization and pref then
		if isFilled(pref.authorization) then
			restParam.bearer_authorization = "Bearer " .. pref.authorization
		end
	elseif conn.basic_authorization and pref then
		if isFilled(pref.user) and isFilled(pref.password) then
			restParam.basic_authorization = "Basic " .. base64.encode(pref.user .. ":" .. pref.password)
		end
		if isFilled(pref.authorization) then
			restParam.basic_authorization = pref.authorization
		end
	end
	if prf.call_body and conn.call_body then
		restParam.content = util.tableCombine(conn.call_body, prf.call_body, "no-error")
	elseif prf.call_body or conn.call_body then
		restParam.content = util.clone(prf.call_body or conn.call_body)
	end
	-- guard: prf.parameter may be absent and call_body may be a plain (JSON) string
	if type(restParam.content) == "table" and type(prf.parameter) == "table" then
		for key, val in pairs(restParam.content) do
			if prf.parameter[val] then
				restParam.content[key] = prf.parameter[val] -- replacing existing keys during pairs() is allowed
			end
		end
	end
	if conn.query_parameter then
		if conn.auth and conn.query_parameter[conn.auth.auth_tag] == "" then
			-- conn.query_parameter[auth.auth_tag] is empty, for example: "query_parameter": {"access_token": ""}
			local err = callAuth(prf, restParam, conn)
			if err then
				return {error = err}
			end
		else
			util.recToRec(restParam.query_parameter, conn.query_parameter)
		end
	end
	return restParam
end

--- Run query on conn with authentication handling.
-- Builds the parameters with restParameter(), merges the optional extra body fields in
-- content, and authenticates first with conn.code_method.authenticate when configured.
-- If the answer's error tag equals conn.auth.error_answer (expired/invalid token), it
-- re-authenticates once and repeats the call. Error-tag values are copied into
-- result.error, with the error_tag2 value appended when present.
-- @param query query preference
-- @param conn connection table
-- @param content optional table of body fields copied into the request body
-- @return result, restParam — or {error = err} when parameters/authentication failed,
--   or callRest.call()'s error string when no connection could be made
function callRest.authCall(query, conn, content)
	local restParam = callRest.restParameter(query, conn)
	if restParam.error then
		-- authentication inside restParameter() failed; restParam is not a usable request
		return {error = restParam.error}
	end
	if content then
		restParam.content = restParam.content or {}
		util.recToRec(restParam.content, content)
		-- restParam = util.tableCombine(restParam, param, "no-error")
	end
	if not conn.authenticated and conn.code_method and conn.code_method.authenticate then
		if conn.code_function == nil then
			local execute = require "execute"
			conn.code_function = {}
			for key, val in pairs(conn.code_method) do
				conn.code_function[key] = execute.getFunction(val)
			end
		end
		local err = callAuth(query, restParam, conn)
		if err then
			return {error = err}
		end
	end
	if type(query.parameter and query.parameter.column) == "string" then
		query.parameter.column = dprf.prf(query.parameter.column)
	end
	local result = callRest.call(query, restParam, conn) -- not: callAuth()
	if type(result) ~= "table" then
		return result, restParam -- connection error string from callRest.call()
	end
	local errorTag = query.error_tag or conn.error_tag
	local ret = result.data
	if type(ret) == "table" and ret[errorTag] then
		-- conn.query_parameter[auth.auth_tag] has expired or is invalid, for example: "query_parameter": {"access_token": "1234"}
		if conn.auth and ret[errorTag] == conn.auth.error_answer then
			local err = callAuth(query, restParam, conn)
			if err then
				return {error = err}
			end
			result = callRest.call(query, restParam, conn)
			if type(result) ~= "table" then
				return result, restParam
			end
			ret = result.data
		end
		if type(ret) == "table" then
			result.error = result.error or ret[errorTag]
			local errorTag2 = query.error_tag2 or conn.error_tag2
			-- append the detail only when there is both an error and a detail value
			if result.error and errorTag2 ~= nil and ret[errorTag2] then
				result.error = tostring(result.error) .. "\n" .. tostring(ret[errorTag2])
			end
		end
	end
	return result, restParam
end

--- Execute one REST call on conn and decode the answer.
-- Connects the call socket when missing (from paramTbl.host), fills {{key}} placeholders
-- in a table body from conn.auth_preference, sends the request (conn.code_function.callQuery
-- when present, otherwise call()), and JSON-decodes answers starting with "{" or "[" unless
-- util.from4d(). A decoded table's data_tag field (query_parameter.data_tag or conn.data_tag)
-- becomes result.data when present. On any error the socket is closed and the error printed.
-- @param query query preference
-- @param param request parameter table, or a JSON string of it (then also sent as the body)
-- @param conn connection table
-- @return result table {conn, data, header, error, ...}, or an "error: ..." string when no
--   socket connection could be made
function callRest.call(query, param, conn)
	local paramTbl
	if type(param) == "string" then
		paramTbl = json.fromJson(param)
		if type(paramTbl) == "table" then
			paramTbl.content = param
		end
	else
		paramTbl = param
	end
	-- created up front so that a bad parameter can be reported in result.error
	local result = {conn = conn}
	local err
	if type(paramTbl) ~= "table" then
		err = l("parameter is not a table, parameter: '%s'", tostring(param))
	else
		local authPreference = conn.auth_preference
		if type(authPreference) == "table" and type(paramTbl.content) == "table" then
			for key, val in pairs(authPreference) do
				local placeholder = "{{" .. tostring(key) .. "}}"
				for key2, val2 in pairs(paramTbl.content) do
					if type(val2) == "string" and peg.found(val2, placeholder) then
						if type(val) ~= "string" then
							-- keep the non-string type of the preference value
							paramTbl.content[key2] = val
						else
							paramTbl.content[key2] = peg.replace(val2, placeholder, tostring(val))
						end
					end
				end
			end
		end
		local connectSocket = conn.driver_conn and conn.driver_conn.socket
		if not connectSocket then
			-- use paramTbl, not param: param may be the JSON string form
			if not paramTbl.host then
				return "error: " .. l("call worker socket connection failed, no parameter.host, preference: %s", json.toJson(loc.connectPrf)) -- error MUST be in english
			end
			local callConn = restConnect(paramTbl.host, conn)
			connectSocket = callConn.socket
			if not connectSocket then
				return "error: " .. l("call worker socket connection failed, preference: %s", json.toJson(paramTbl.host)) -- error MUST be in english
			end
			if conn.driver_conn == nil then
				-- otherwise sendCall() could not find the new socket and it would leak
				conn.driver_conn = callConn
			elseif conn.driver_conn.socket == nil then
				conn.driver_conn.socket = connectSocket
			end
		end
		if conn.code_function and conn.code_function.callQuery then
			err = conn.code_function.callQuery(query, result, paramTbl)
		else
			err = call(query, result, paramTbl)
		end
		if doDisconnect == true or paramTbl.disconnect == true then
			restDisconnect(conn)
		end
	end
	if type(paramTbl) == "table" and paramTbl.debug and type(result.data) == "string" then
		local queryParameter = paramTbl.query_parameter
		local len = queryParameter and queryParameter.result_debug_length or conn.result_debug_length or loc.result_debug_length
		util.printInfo("* rest answer:\n%s\n", result.data:sub(1, len))
	end
	if err then
		result.error = "error: " .. tostring(err)
	elseif result.data and type(result.data) ~= "table" then
		local raw = result.data
		local firstChar = type(raw) == "string" and raw:sub(1, 1) or ""
		if not util.from4d() and (firstChar == "{" or firstChar == "[") then
			local data, jsonErr = json.fromJson(raw)
			if jsonErr then
				if result.error then
					result.error = result.error .. "\n  - json error: " .. tostring(jsonErr)
				else
					result.error = jsonErr
				end
			else
				local queryParameter = paramTbl.query_parameter
				local dataTag = queryParameter and queryParameter.data_tag or conn.data_tag
				if type(data) == "table" and dataTag ~= nil and data[dataTag] then
					result.data = data[dataTag]
				else
					result.data = data
				end
			end
		else
			result.data = tostring(raw) -- json.toJsonRaw(ret)
		end
	end
	if result.error then
		restDisconnect(conn)
		util.printError(result.error)
	end
	return result
end

-- Module exports: disconnect, connect, restParameter, authCall, call.
return callRest
