-- lib/system/poll.lua
-- partly copied from: https://github.com/chatid/fend/blob/master/poll.lua
local poll = {}

local ffi = require "mffi"
local C = ffi.C
local util = require "util"
local netPoll = require"system/net".poll
local bit = require "bit"
local socket = require "system/socket"
local dt = require "dt"
local peg = require "peg"
local dconn = require "dconn"
local dtCurrentString = dt.currentString
local band = bit.band
local print = util.print
local from4d = util.from4d()
local ioWrite = io.write

local coro = require "coro"
local useCoro = coro.useCoro()
local threadId = coro.threadId
local coroResume = coro.resume
local socketClosed = coro.socketClosed
local setThreadSocket = coro.setThreadSocket
local threadCount, connectionCount = coro.threadCount, coro.connectionCount
local addSocket -- delay load server.addSocket

-- C struct holding the raw pointer (p1) to the pollfd array. A metatype with a
-- __gc handler is attached to it further down in this file, and that handler
-- frees p1 when the struct instance is finalized.
ffi.cdef [[
	typedef struct {
			void* p1;
			// void* p2;
	} finalizer_st;
]]

-- Event dispatcher (selected by setDebugLevel) and the user supplied callbacks.
local pollEvent, inCallback, outCallback, closeCallback, errorCallback
-- fdsStruct: finalizer_st instance owning the C memory; fds: that memory cast to
-- "struct pollfd*"; nfds: number of entries of fds that are in use.
local fdsStruct, fds, nfds
-- fdsListSize: capacity of fds in entries; fdsListAddCount: growth step in entries.
local fdsListSize, fdsListAddCount, timeout, debugLevel
-- Statistics counters exposed through poll.pollCount/fdAddCount/fdRemoveCount.
local totalPollCount, totalFdAddCount, totalFdRemoveCount
-- Maps tonumber(sock.socket) to the socket object registered for that fd.
local socketIdx = {}
-- Requested event masks per socket type.
local listenEvents = C.POLLIN -- C.POLLOUT is set to each socket in socket.send() for yield() time only
local acceptEvents = C.POLLIN -- bit.bor(C.POLLIN, C.POLLPRI)
local connectEvents = C.POLLIN
-- Hang-up bit tested by the event handlers; on Linux POLLRDHUP is used in its place.
-- NOTE(review): with this substitution a plain POLLHUP revent on Linux is not
-- recognised as a hang-up by the handlers — confirm this is intended.
local POLLHUP = C.POLLHUP
if util.isLinux() then
	POLLHUP = C.POLLRDHUP
end

-- On Windows C is re-bound to the namespace returned by ffi.loadMsvcr(); this happens
-- only after the POLL* constants above have been read from the default namespace.
if ffi.os == "Windows" then
	C = ffi.loadMsvcr()
else
	-- The hang-up bit is added to the requested events on non-Windows platforms only.
	listenEvents = bit.bor(listenEvents, POLLHUP)
	acceptEvents = bit.bor(acceptEvents, POLLHUP)
	connectEvents = bit.bor(connectEvents, POLLHUP) -- POLLHUP is not allowed in Windows as input events, but is returned in revents
	-- https://stackoverflow.com/questions/55524397/why-when-i-add-pollhup-as-event-wsapoll-returns-error-invalid-arguments
end

--- Render the active part of the pollfd array as text for log messages.
-- @treturn string text of the form "fds[fd1, fd2, ...], nfds=N"
local function fdArrShow()
	local parts = {}
	local count = nfds
	local pos = 0
	while pos < count do
		-- tostring() because fds[idx].fd is uint64_t in windows 64
		parts[pos + 1] = tostring(fds[pos].fd)
		pos = pos + 1
	end
	local joined = table.concat(parts, ", ")
	return "fds[" .. joined .. "], nfds=" .. nfds
end

-- Metatable for finalizer_st: when the struct instance is garbage collected it
-- disconnects everything, closes all sockets and releases the pollfd memory.
local mt = {
	--- Finalizer of the struct that owns the pollfd memory.
	-- @param self finalizer_st instance; self.p1 is the memory block to release
	__gc = function(self)
		print("finalizer: free p1=" .. tostring(self.p1) .. ", " .. fdArrShow())
		if self.p1 ~= nil then
			dconn.disconnectAll(nil, "stop application")
			socket.closeAll()
			if nfds ~= 0 then
				util.printRed("error: open sockets after close: " .. fdArrShow())
			end
			-- Free through the same namespace that expandFds uses for realloc. On
			-- Windows C is re-bound to ffi.loadMsvcr(), and memory must be released
			-- by the runtime that allocated it.
			C.free(self.p1)
			self.p1 = nil -- never leave a dangling pointer behind
			util.printOk("exit ok, bye")
		end
	end
}
local finalizer = ffi.metatype("finalizer_st", mt)

--- Resize the C array of "struct pollfd" owned by fdsStruct.
-- The block that is actually resized is fdsStruct.p1; oldFds is only used for
-- log output.
-- @param oldFds current "struct pollfd*" or nil (logging only)
-- @tparam number countFds number of entries the array must be able to hold
-- @return "struct pollfd*" pointing at the resized array
-- Raises an error when realloc fails; the previous block stays owned by
-- fdsStruct in that case, so nothing is leaked.
local function expandFds(oldFds, countFds)
	local size = ffi.sizeof("struct pollfd") * countFds
	-- Capture the text now: after realloc the old pointer may no longer be valid.
	local oldFdsText = tostring(oldFds)
	if oldFds then
		print("poll.expandFds, free old fds: " .. oldFdsText)
	end
	-- Keep the result in a temporary so a failed realloc does not overwrite
	-- (and thereby leak) the still valid old block.
	local mem = C.realloc(fdsStruct.p1, size)
	if mem == nil then
		util.printError("Cannot re-allocate memory (poll.expandFds)")
		error("poll.expandFds: cannot re-allocate " .. size .. " bytes")
	end
	fdsStruct.p1 = mem
	local newFds = ffi.cast("struct pollfd*", mem)
	print(string.format("poll.expandFds: %d, old fds: %s, new fds: %s, size: %d", countFds, oldFdsText, tostring(newFds), size))
	return newFds

	--[[
	local newFds
	if oldFds then
		ffi.gc(oldFds , nil)
		-- oldFds will be used in realloc, but not garbage collect it, remove it's gc function
	end
	newFds = C.realloc(oldFds, ffi.sizeof("struct pollfd") * countFds)
	if newFds == nil then
		util.printError("Cannot re-allocate memory (poll.expandFds)")
	end
	ret = ffi.cast("struct pollfd*", newFds)
	return ffi.gc(ret, C.free) -- assign ffi.C.free for garbage collect
	--]]
end

--- Reset the module to its initial state.
-- Clears the dispatcher and all callbacks, empties the fd -> socket index,
-- (re)allocates the pollfd array at its initial capacity and zeroes the
-- counters. Called once at load time and again from poll.removeAll().
local function clearAll()
	pollEvent = nil
	inCallback = nil -- runs this function when data has come in
	outCallback = nil -- runs this function when you can write out
	closeCallback = nil -- runs this function when you need to close socket
	errorCallback = nil -- runs this function when error has happened
	-- Drop every fd -> socket mapping. Without this, entries survive a
	-- poll.removeAll() and a later socket that reuses the same fd number would
	-- still resolve to the old socket object.
	for sockNum in pairs(socketIdx) do
		socketIdx[sockNum] = nil
	end
	if not fdsStruct then
		fdsStruct = finalizer()
	end
	fdsListAddCount = 50
	fdsListSize = fdsListAddCount -- how many fd's can fit in to fds memory size
	fds = expandFds(fds, fdsListSize) -- ffi.C memory area containing all (max. fdsListSize) "struct pollfd":s
	nfds = 0 -- number of active fds
	timeout = 0
	debugLevel = 0
	totalPollCount = 0
	totalFdAddCount = 0
	totalFdRemoveCount = 0
end
clearAll()

--- Find the 0-based position of a descriptor in the pollfd array.
-- The search runs from the last active entry towards the first.
-- @param fd descriptor value to look for
-- @treturn number 0-based index, or -1 when fd is not in the array
local function fdArrIndex(fd)
	local pos = nfds
	while pos > 0 do
		pos = pos - 1
		if fds[pos].fd == fd then
			return pos
		end
	end
	return -1
end

--- Enable or disable POLLOUT in the requested events of a registered socket.
-- @param sock socket object; sock.socket is the descriptor to look up
-- @param enable truthy to request POLLOUT, falsy to stop requesting it
function poll.setOutFlag(sock, enable)
	local idx = fdArrIndex(sock.socket)
	if idx < 0 then
		util.printError("poll.setOutFlag, socket '%s' not found, %s", tostring(sock.socket), fdArrShow())
		return
	end
	local events = fds[idx].events
	if enable then
		fds[idx].events = bit.bor(events, C.POLLOUT) -- sets POLLOUT bit
	else
		-- Mask the bit out. An xor would toggle it and switch POLLOUT on when
		-- it was not set before.
		fds[idx].events = band(events, bit.bnot(C.POLLOUT)) -- removes POLLOUT bit
	end
end

--- Register a socket in the poll set.
-- Appends the descriptor to the pollfd array (growing it when needed), stores
-- the socket in the fd -> socket index and hands it to server.addSocket.
-- @param sock socket object; sock.socket is the descriptor, sock.info is used in messages
-- @tparam string socketType "listen", "accept" or "connect"
function poll.addFd(sock, socketType)
	local idx = fdArrIndex(sock.socket)
	local sockNum = tonumber(sock.socket)
	local events
	if socketType == "listen" then
		events = listenEvents
	elseif socketType == "accept" then
		events = acceptEvents
	elseif socketType == "connect" then
		events = connectEvents
	else
		-- The format string has five placeholders; the socket type must be passed too.
		util.printError("poll socket '%s' type '%s' is not valid, %s, idx %d, %s", tostring(sockNum), tostring(socketType), tostring(sock.info), idx, fdArrShow())
		return
	end
	sock.socket_type = socketType
	totalFdAddCount = totalFdAddCount + 1
	if idx >= 0 then
		-- is old socket number, is ok when we reuse addresses ???
		util.printError("poll socket '%s' was already added to array, %s, idx %d, %s", tostring(sockNum), tostring(sock.info), idx, fdArrShow())
		return
	end
	if nfds + 1 >= fdsListSize then -- expand nfds C memory area
		-- Commit the new capacity only after the reallocation has succeeded.
		local newSize = fdsListSize + fdsListAddCount
		fds = expandFds(fds, newSize)
		fdsListSize = newSize
	end
	if socketIdx[sockNum] then
		-- The fd is not in the array, so this index entry is stale. Report it.
		util.printError("poll socket '%s' was already added to socket index, %s, idx %d, %s", tostring(sockNum), tostring(sock.info), idx, fdArrShow())
	end
	-- Always point the index at the socket that now owns this fd, otherwise
	-- events would be dispatched to the stale socket object.
	socketIdx[sockNum] = sock
	fds[nfds].fd = sock.socket -- set C struct pollfd field fd, same as fds[nfds].fd = fd
	fds[nfds].events = events -- bor(C.POLLIN, C.POLLOUT, C.POLLRDHUP)
	-- fds[nfds].revents 	= 0 -- no need to set
	nfds = nfds + 1 -- fds is C-mem area and 0-based, so add it only in the end
	if addSocket == nil then
		-- loaded lazily (see the "delay load" note at the declaration of addSocket)
		addSocket = require"system/server".addSocket
	end
	addSocket(sock, socketType) -- per the original note: creates a new coroutine and stores it in sock.thread
	if useCoro and not sock.thread then
		setThreadSocket(sock)
	end
	if debugLevel > 0 then
		if sock.thread then
			util.printInfo("  -- add poll socket '%s', %s, %s, %s", tostring(sockNum), threadId(sock.thread), tostring(sock.info), fdArrShow())
		else
			util.printInfo("  -- add poll socket '%s', %s, %s", tostring(sockNum), tostring(sock.info), fdArrShow())
		end
	end
end

--- Unregister a socket from the poll set.
-- Removes the descriptor from the pollfd array (the last entry is moved into
-- the freed slot) and from the fd -> socket index. If the socket is not yet
-- closing and has a thread, that thread is resumed with sock.do_close set.
-- When coroutines are in use, socketClosed(sock) is called at the end.
-- @param sock socket object; sock.socket is the descriptor to remove
local function removeFd(sock) -- , showError)
	totalFdRemoveCount = totalFdRemoveCount + 1
	local idx = fdArrIndex(sock.socket)
	local sockNum = tonumber(sock.socket)
	if idx < 0 then
		if nfds == 0 and (sock.info == nil or peg.found(sock.info, "listen")) then
			return
		end
		-- if showError ~= false then
		util.printError("poll remove fd socket '%s' was not found from array, socket type '%s', %s", tostring(sock.socket), tostring(sock.info), fdArrShow())
		-- end
		return
	end
	if socketIdx[sockNum] == nil then
		util.printError("poll remove fd socket '%s' was not found from socket, socket type '%s', %s", tostring(sock.socket), tostring(sock.info), fdArrShow())
		util.print("socket idx: %s", socketIdx)
	else
		socketIdx[sockNum] = nil
	end
	if idx ~= nfds - 1 then -- if not last item, move an item from end of list to fill the empty spot
		if debugLevel > 2 then
			print("  poll remove fd from middle: idx %d, socket '%s', socket type '%s', %s", idx + 1, tostring(sock.socket), tostring(sock.info), fdArrShow())
		end
		nfds = nfds - 1 -- decrease nfds count so that fds[nfds] is zero-based
		fds[idx].fd = fds[nfds].fd
		fds[idx].events = fds[nfds].events
		-- Move the pending events too. removeFd can run from inside a callback
		-- while poll.pollFd is still iterating; without this the moved socket
		-- would be dispatched with the revents of the socket just removed.
		fds[idx].revents = fds[nfds].revents
		fds[nfds].fd = -1 -- invalidate the vacated slot, same as in the "from end" branch
		if debugLevel > 0 then
			print("  poll removed fd from middle: idx %d, socket '%s', socket type '%s', %s", idx + 1, tostring(sock.socket), sock.info, fdArrShow())
		end
	else
		if debugLevel > 2 then
			print("  poll remove fd from end: idx %d, socket='%s', socket type '%s', %s", idx + 1, tostring(sock.socket), sock.info, fdArrShow())
		end
		nfds = nfds - 1 -- decrease nfds count so that fds[nfds] is zero-based
		fds[nfds].fd = -1
		-- fds[nfds].events = -1
		if debugLevel > 0 then
			print("  poll removed fd from end: idx %d, socket '%s', socket type '%s', %s", idx + 1, tostring(sock.socket), sock.info, fdArrShow())
		end
	end
	if not (sock.do_close or sock.closed) and sock.thread then
		sock.do_close = true
		if debugLevel > 0 then
			util.printInfo("  poll remove fd, resume socket thread: idx %d, socket='%s', socket type '%s', %s", idx + 1, tostring(sock.socket), sock.info, fdArrShow())
		end
		coroResume(sock.thread) -- resume so that coroutine will exit
	end
	if useCoro then
		socketClosed(sock)
	end
end
poll.removeFd = removeFd

--- Call close_func for every registered descriptor, then reset the module.
-- The number of entries to visit is fixed before the first call.
-- @tparam function close_func called as close_func(fd) for each active entry
function poll.removeAll(close_func)
	local count = nfds
	local pos = 0
	while pos < count do
		close_func(fds[pos].fd)
		pos = pos + 1
	end
	clearAll()
end

--- Number of descriptors currently in the poll set.
-- @treturn number
poll.fdCount = function()
	return nfds
end

--- Number of poll rounds executed by poll.pollFd since the last reset.
-- @treturn number
poll.pollCount = function()
	return totalPollCount
end

--- Number of poll.addFd calls counted since the last reset.
-- @treturn number
poll.fdAddCount = function()
	return totalFdAddCount
end

--- Number of removeFd calls counted since the last reset.
-- @treturn number
poll.fdRemoveCount = function()
	return totalFdRemoveCount
end

--- Set the timeout handed to netPoll when no fast timeout is active.
-- @param timeOut timeout value (presumably milliseconds, like the fast timeouts in poll.pollFd — TODO confirm)
poll.setTimeout = function(timeOut)
	-- util.print("  - poll set timeout: %d", timeOut)
	timeout = timeOut
end

--- Current timeout as set by poll.setTimeout (0 after a reset).
-- @return the stored timeout value
poll.getTimeout = function()
	return timeout
end

--- Register the handler called as func(sock) when a socket reports POLLIN.
poll.setInCallback = function(func)
	inCallback = func
end

--- Register the handler called as func(sock) when a socket reports POLLOUT.
poll.setOutCallback = function(func)
	outCallback = func
end

--- Register the handler called as func(sock, "POLLHUP") on hang-up.
poll.setCloseCallback = function(func)
	closeCallback = func
end

--- Register the handler called as func(sock, reason) for POLLNVAL, POLLERR
-- and unknown events.
poll.setErrorCallback = function(func)
	errorCallback = func
end

-- http://www.greenend.org.uk/rjk/tech/poll.html
-- *** the event handlers use elseif on purpose: an event that is still needed will be reported again by the next poll *** --

--- Dispatch one poll result to the registered callbacks (no logging).
-- At most one of in / out / error is served per call; a hang-up additionally
-- marks the socket for closing and triggers the close callback afterwards.
-- @tparam number evt revents bit mask
-- @param sock socket object the event belongs to
local function pollEventNoDebug(evt, sock)
	local hangup = band(evt, POLLHUP) ~= 0
	if hangup then
		sock.do_close = true
	end
	if band(evt, C.POLLIN) ~= 0 then
		sock.sending = "pollin"
		inCallback(sock)
		sock.sending = nil
	elseif band(evt, C.POLLOUT) ~= 0 then
		sock.sending = "pollout"
		outCallback(sock)
		sock.sending = nil
	elseif band(evt, C.POLLNVAL) ~= 0 then
		errorCallback(sock, "POLLNVAL")
	elseif not hangup then
		-- errors are reported only when the event is not a hang-up as well
		if band(evt, C.POLLERR) ~= 0 then
			errorCallback(sock, "POLLERR")
		else
			errorCallback(sock, "unknown")
		end
	end
	if not sock.sending and hangup then
		closeCallback(sock, "POLLHUP")
	end
end

--- Dispatch one poll result to the registered callbacks and log every step.
-- Dispatch follows the same rules as pollEventNoDebug, so raising the debug
-- level only adds output and never changes which callbacks run.
-- @tparam number evt revents bit mask
-- @param sock socket object the event belongs to
-- @tparam number i position used in the log text (negative for resumed sockets)
local function pollEventDebug(evt, sock, i)
	local socketNum = tonumber(sock.socket) or "nil"
	local hangup = band(evt, POLLHUP) ~= 0
	-- label is the fixed-width event name; nfds is read when the line is printed
	local function show(label)
		print(totalPollCount .. ". " .. label .. ": idx=" .. i .. ", evt=" .. evt .. ", nfds=" .. nfds .. ", socket=" .. socketNum)
	end
	ioWrite("\n")
	if hangup then
		ioWrite(totalPollCount .. ". POLLHUP + ")
		sock.do_close = true
	end
	if band(evt, C.POLLIN) ~= 0 then
		show("POLLIN  ")
		sock.sending = "pollin"
		inCallback(sock)
		sock.sending = nil
		-- sock.answer_size = 0 -- for debugging
	elseif band(evt, C.POLLOUT) ~= 0 then
		show("POLLOUT ")
		sock.sending = "pollout"
		outCallback(sock)
		sock.sending = nil
	elseif band(evt, C.POLLNVAL) ~= 0 then
		-- POLLNVAL, output only
		show("POLLNVAL")
		errorCallback(sock, "POLLNVAL")
	elseif band(evt, C.POLLERR) ~= 0 and not hangup then
		-- POLLERR, output only. As in pollEventNoDebug, an error that arrives
		-- together with a hang-up is handled by the close callback alone.
		show("POLLERR ")
		errorCallback(sock, "POLLERR")
		--[=[elseif band(evt, C.POLLPRI) ~= 0 then
    -- POLLPRI Priority data may be read without blocking. This flag is not supported by the Microsoft Winsock provider.
    if debugLevel > 0 then print(totalPollCount..". POLLPRI : idx="..i..", evt="..evt..", fd="..tostring(fd)..", nfds="..nfds) end
    inCallback(fd)]=]
	elseif not hangup then
		show("UNKNOWN POLL EVENT  ")
		errorCallback(sock, "unknown")
	end
	if not sock.sending and hangup then
		show("POLLHUP ")
		closeCallback(sock, "POLLHUP")
	end
end

--- Store the debug level and select the matching event dispatcher.
-- Levels above 2 use the logging dispatcher, all others the silent one.
-- @tparam number level new debug level
local function setDebugLevel(level)
	debugLevel = level
	local verbose = debugLevel > 2
	-- both dispatchers are functions, so the and/or selection is safe here
	pollEvent = verbose and pollEventDebug or pollEventNoDebug
	-- print("poll debug level:  "..debugLevel, pollEvent)
end
poll.setDebugLevel = setDebugLevel

do
	-- State kept between poll.pollFd calls.
	local pollRet, pollEvt, sock, served, fdsItem
	local fastTimeout = nil -- one-shot timeout that overrides 'timeout' for the next poll
	local returnPoll = true
	local waitPollCount = 0
	local noPollCount = 0 -- consecutive rounds in which nothing was served
	local resumeCount = 0
	local resumeArr = {}

	--- Run the poll loop.
	-- Polls all registered descriptors, dispatches every reported event and then
	-- serves sockets that have sock.resume set as if they had reported POLLIN.
	-- Keeps looping while something was served; when from4d is set it returns
	-- after a single round.
	-- @return pollRet result of the last netPoll call
	-- @return noPollCount number of consecutive rounds without served events
	function poll.pollFd()
		repeat
			totalPollCount = totalPollCount + 1
			if fastTimeout == nil and not from4d then
				noPollCount = noPollCount + 1
				if noPollCount < 2 then
					fastTimeout = 3000 -- delay 3 seconds for 'waiting for calls' -message, bad for performance tests and less text on screen
				elseif noPollCount == 2 then
					fastTimeout = 10000 -- delay additional 10 seconds, then return for collectgarbage and for 'waiting for calls' -message until next real poll event (wait forever)
				else
					waitPollCount = waitPollCount + 1
				end
			end
			if not from4d and noPollCount > 1 then
				print("\n%d/%d. threads: %d, connections: %s, %s, %s, waiting for calls...", waitPollCount, totalPollCount, threadCount(), connectionCount(), fdArrShow(), dtCurrentString())
			end
			pollRet = netPoll(fds, nfds, fastTimeout or timeout)
			fastTimeout = nil
			returnPoll = true
			served = 0
			-- Reset on every round, including the error path below; a stale count
			-- from an earlier round would otherwise force another zero-timeout loop.
			resumeCount = 0
			if pollRet == -1 then
				print(totalPollCount .. ". poll, nfds=" .. nfds)
				-- fds[0] holds a registered socket only when the poll set is not empty
				sock = nil
				if nfds > 0 then
					sock = socketIdx[tonumber(fds[0].fd)]
				end
				socket.cleanup(sock, pollRet, "poll failed with error: ")
			else -- if pollRet > 0 or hasResume then -- a pollRet value of 0 indicates that the call timed out and no file descriptors have been selected
				-- loop all events
				for pollIdx = 1, nfds do -- listen sockets are first, serve them first instead of more complicated code of looping from previous
					if pollIdx > nfds then -- nfds can shrink while callbacks remove sockets
						break
					end
					fdsItem = fds[pollIdx - 1]
					pollEvt = fdsItem.revents
					sock = socketIdx[tonumber(fdsItem.fd)]
					if sock == nil then
						-- report the descriptor of the entry that failed the lookup
						util.printError("socket was not found from socket index (poll.pollFd 1): socket=" .. tostring(fdsItem.fd) .. ", nfds=" .. nfds)
						break
					end
					if pollEvt == 0 and sock.resume and not sock.closed then
						resumeCount = resumeCount + 1
						resumeArr[resumeCount] = sock
						if sock.lock_topic then
							returnPoll = false
						end
						-- pollEvt = C.POLLIN
						-- pollRet = pollRet + 1
					elseif pollEvt ~= 0 then
						served = served + 1
						-- pollEvent is nil until setDebugLevel has been called;
						-- fall back to the silent dispatcher in that case
						;(pollEvent or pollEventNoDebug)(pollEvt, sock, pollIdx)
					end
				end
				if resumeCount > 0 then -- serve sock.resume sockets after serving real poll events first
					for resumeIdx = 1, resumeCount do
						;(pollEvent or pollEventNoDebug)(C.POLLIN, resumeArr[resumeIdx], -resumeIdx)
					end
				end
			end
			if served > 0 or resumeCount > 0 then
				-- something was served: poll again immediately without waiting
				noPollCount = 0
				fastTimeout = 0
				returnPoll = false
			end
		until returnPoll or from4d
		return pollRet, noPollCount
	end
end

-- Table holding references to the C memory owner and the pollfd pointer.
local anchor = {fdsStruct = fdsStruct, fds = fds}

--- Print the current memory owner and pollfd pointer (debug helper).
local function printAnchor()
	-- fds is replaced whenever the array is reallocated, so refresh the
	-- references first instead of printing the values captured at load time.
	anchor.fdsStruct = fdsStruct
	anchor.fds = fds
	util.printTable(anchor)
end
poll.printAnchor = printAnchor

return poll

--[[
local function pollEventNoDebug(evt, sock)
	if band(evt, POLLHUP) ~= 0 then
		closeCallback(sock, "POLLHUP")
	elseif band(evt, C.POLLIN) ~= 0 then
		inCallback(sock)
	elseif band(evt, C.POLLOUT) ~= 0 then
		outCallback(sock)
	elseif band(evt, C.POLLNVAL) ~= 0 then
		errorCallback(sock, "POLLNVAL")
	elseif band(evt, C.POLLERR) ~= 0 then
		errorCallback(sock, "POLLERR")
	else
		errorCallback(sock, "unknown")
	end
end

local function pollEventDebug(evt, sock, i)
	local socketNum = tonumber(sock.socket) or "nil"
	ioWrite("\n")
	if band(evt, POLLHUP) ~= 0 then
		print(totalPollCount .. ". POLLHUP : idx=" .. i .. ", evt=" .. evt .. ", nfds=" .. nfds .. ", socket=" .. socketNum)
		closeCallback(sock, "POLLHUP")
	elseif band(evt, C.POLLIN) ~= 0 then
		print(totalPollCount .. ". POLLIN  : idx=" .. i .. ", evt=" .. evt .. ", nfds=" .. nfds .. ", socket=" .. socketNum)
		inCallback(sock)
		-- sock.answer_size = 0 -- for debugging
	elseif band(evt, C.POLLOUT) ~= 0 then
		print(totalPollCount .. ". POLLOUT : idx=" .. i .. ", evt=" .. evt .. ", nfds=" .. nfds .. ", socket=" .. socketNum)
		outCallback(sock)
	elseif band(evt, C.POLLNVAL) ~= 0 then
		-- POLLNVAL, output only
		print(totalPollCount .. ". POLLNVAL: idx=" .. i .. ", evt=" .. evt .. ", nfds=" .. nfds .. ", socket=" .. socketNum)
		errorCallback(sock, "POLLNVAL")
	elseif band(evt, C.POLLERR) ~= 0 then
		-- POLLERR, output only
		print(totalPollCount .. ". POLLERR : idx=" .. i .. ", evt=" .. evt .. ", nfds=" .. nfds .. ", socket=" .. socketNum)
		errorCallback(sock, "POLLERR")
		--[=[ elseif band(evt, C.POLLPRI) ~= 0 then
    --POLLPRI Priority data may be read without blocking. This flag is not supported by the Microsoft Winsock provider.
    if debugLevel > 0 then print(totalPollCount..". POLLPRI : idx="..i..", evt="..evt..", fd="..tostring(fd)..", nfds="..nfds) end
    inCallback(fd) ]=]
	else
		print(totalPollCount .. ". UNKNOWN POLL EVENT  : idx=" .. i .. ", evt=" .. evt .. ", nfds=" .. nfds .. ", socket=" .. socketNum)
		errorCallback(sock, "unknown")
	end
end ]]
