-- lib/net/http3-quiche-server.lua
-- LuaJIT FFI port of quiche's examples/http3-server.c (local copy: /Users/pasi/installed/C/tls/quiche/examples/http3-server.c)
package.path = "lib/?.lua;lib/?.lx;" .. package.path
package.path = "../lib/?.lua;../lib/?.lx;" .. package.path
package.path = "../../lib/?.lua;../../lib/?.lx;" .. package.path
require "start"

local ffi = require "mffi"
local C = ffi.C
local util = require "util"
local fs = require "fs"
local l = require"lang".l
local scrypt = require "scrypt"
-- local net = require "system/net"
local socket = require "system/socket"
local quiche = require "net/quiche"
-- Shadow the global print with util.print, and read the ALPN constant off the
-- wrapper module before `quiche` is rebound below.
local print, QUICHE_H3_APPLICATION_PROTOCOL = util.print, quiche.QUICHE_H3_APPLICATION_PROTOCOL
-- From here on `quiche` is the raw FFI library exposed as the wrapper's `lib` field.
quiche = quiche.lib

-- quiche transport config and HTTP/3 config; created in main() and used for every connection.
local config, http3_config

--- Creates an empty per-connection record (the Lua counterpart of the C example's `struct conn_io`).
-- Other fields are filled in later: sock, cid, hashId and conn by create_conn, http3 by recv_cb.
-- @return table with a timer table and a zeroed peer address buffer plus its length.
local function newConnStruct()
	return {
		-- delay: seconds until quiche's next timeout (written by flush_egress); data: back-reference set by create_conn
		timer = {delay = 0, data = ""},
		peer_addr = ffi.newNoAnchor("struct sockaddr_storage"),
		-- socklen_t to match the recvfrom length it is copied from and the sendto address length it feeds
		peer_addr_len = ffi.newNoAnchor("socklen_t[1]"),
	}
end
-- Connection state ------------------------------------------------------------
-- `conn_io` is the connection currently being worked on; the helpers below read
-- it instead of taking it as a parameter.
local conn_io
-- conn_io records keyed by their connection-id string (hashId)
local connections = {}
-- holds the listening socket as conns.sock
local conns = {}

-- Sizes ----------------------------------------------------------------------
local LOCAL_CONN_ID_LEN, MAX_DATAGRAM_SIZE = 16, 1350
local quicheTextLen = #"quiche"
-- Retry token layout: "quiche" .. peer address .. original destination connection id (see mint_token).
local MAX_TOKEN_LEN = quicheTextLen + ffi.sizeof("struct sockaddr_storage") + C.QUICHE_MAX_CONN_ID_LEN

-- Receive loop state ---------------------------------------------------------
local loopCount = 0
local recvBufLen = 65535
-- single receive buffer shared by all datagrams; never freed (lives for the whole process)
local recvBuf = C.malloc(recvBufLen)

local sendFlags, receiveFlags = 0, 0
if util.isLinux() then
	local noSignal = C.MSG_NOSIGNAL
	sendFlags, receiveFlags = noSignal, noSignal
end

-- local callTimeout = 0.8 --seconds
local sleepMillisec = 100 -- pause between recv_cb passes in main()
-- outgoing datagram buffer shared by every send path
local out = ffi.newAnchor("uint8_t[?]", MAX_DATAGRAM_SIZE)
--- quiche debug-log sink (installed through quiche_enable_debug_logging in main).
-- @param line C string holding one log line
-- @param argp opaque user pointer from quiche (unused)
local function debugLog(line, argp)
	local text = ffi.string(line)
	util.printInfo("%s", text)
end

--- Sends every packet quiche has queued for the current connection (`conn_io`)
-- to its peer, then stores the next quiche timeout in conn_io.timer.delay (seconds).
-- @return nil on success, or an error message when packet creation or sendto fails.
local function flush_egress()
	local conn = conn_io.conn
	while true do
		local written = tonumber(quiche.quiche_conn_send(conn, out, MAX_DATAGRAM_SIZE))
		if written == C.QUICHE_ERR_DONE then
			break -- nothing more to send
		end
		if written < 0 then
			local err = l("failed to create packet: %d", written)
			util.printWarning(err)
			return err
		end
		local sent = conn_io.sock:sendto(out, written, sendFlags, conn_io.peer_addr, conn_io.peer_addr_len[0])
		if sent ~= written then
			-- sent may be nil on failure; %d would raise on nil
			local err = l("failed to send, sent: %d / written %d", sent or -1, written)
			util.printWarning(err)
			return err
		end
		util.print("sent %d bytes", sent)
	end
	-- quiche returns uint64 nanoseconds. Convert to a Lua number before dividing;
	-- otherwise LuaJIT performs truncating 64-bit integer division and stores a cdata.
	conn_io.timer.delay = tonumber(quiche.quiche_conn_timeout_as_nanos(conn)) / 1e9
end

--- Logs the transport stats of the current connection (`conn_io`), frees its
-- quiche objects and removes it from `connections`.
local function closeConnection()
	local stats = ffi.newNoAnchor("quiche_stats[1]")
	quiche.quiche_conn_stats(conn_io.conn, stats)
	util.print("connection closed, recv=%d sent=%d lost=%d rtt=%d ns", tonumber(stats[0].recv), tonumber(stats[0].sent), tonumber(stats[0].lost), tonumber(stats[0].rtt))
	-- The HTTP/3 connection is a separate quiche allocation and must be freed as well.
	if conn_io.http3 ~= nil and not ffi.isNull(conn_io.http3) then
		quiche.quiche_h3_conn_free(conn_io.http3)
	end
	conn_io.http3 = nil
	quiche.quiche_conn_free(conn_io.conn)
	-- Use the key saved at creation time: conn_io.cid may point at a buffer that
	-- was overwritten by a later connection id, which would remove the wrong entry.
	local hashId = conn_io.hashId or ffi.string(conn_io.cid, LOCAL_CONN_ID_LEN)
	if not connections[hashId] then
		util.printWarning("closeConnection, connection does not exist, hashId '%s'", hashId)
	end
	connections[hashId] = nil
end

--- Timer handler for the current connection (`conn_io`): lets quiche process the
-- expired timeout, flushes resulting packets and closes the connection if quiche did.
local function timeout_cb()
	local conn = conn_io.conn
	quiche.quiche_conn_on_timeout(conn)
	print("timeout")
	flush_egress()
	if not quiche.quiche_conn_is_closed(conn) then
		return
	end
	closeConnection()
end

--- Builds an address-validation token: "quiche" .. peer address .. dcid.
-- @param dcid destination connection id bytes; dcid_len is a 1-element length array
-- @param addr peer address bytes; addr_len is a 1-element length array
-- @param token output buffer (must hold MAX_TOKEN_LEN bytes)
-- @param token_len 1-element array that receives the token length
local function mint_token(dcid, dcid_len, addr, addr_len, token, token_len)
	local addrBytes, dcidBytes = addr_len[0], dcid_len[0]
	local cursor = token
	ffi.copy(cursor, "quiche", quicheTextLen)
	cursor = cursor + quicheTextLen
	ffi.copy(cursor, addr, addrBytes)
	cursor = cursor + addrBytes
	ffi.copy(cursor, dcid, dcidBytes)
	token_len[0] = quicheTextLen + addrBytes + dcidBytes
end

--- Checks a token produced by mint_token and extracts the original dcid from it.
-- @param token token bytes; token_len is a 1-element length array
-- @param addr current peer address; addr_len is a 1-element length array
-- @param odcid output buffer; odcid_len holds its capacity on input and the copied length on success
-- @return true when the prefix and address match and the dcid fits, false otherwise
local function validate_token(token, token_len, addr, addr_len, odcid, odcid_len)
	local remaining = token_len[0]
	local addrBytes = addr_len[0]

	-- fixed "quiche" prefix
	if remaining < quicheTextLen then
		return false
	end
	if C.memcmp(token, ffi.cast("char*", "quiche"), quicheTextLen) ~= 0 then
		return false
	end
	local cursor = token + quicheTextLen
	remaining = remaining - quicheTextLen

	-- peer address must be the one the token was minted for
	if remaining < addrBytes then
		return false
	end
	if C.memcmp(cursor, addr, addrBytes) ~= 0 then
		return false
	end
	cursor = cursor + addrBytes
	remaining = remaining - addrBytes

	-- whatever is left is the original destination connection id
	if remaining > odcid_len[0] then
		return false
	end
	ffi.copy(odcid, cursor, remaining)
	odcid_len[0] = remaining
	return true
end

--- Generates a random local connection id that is not yet used in `connections`.
-- A fresh buffer is allocated on every call. The returned buffer is stored as
-- conn_io.cid, so a shared buffer would make every connection alias the newest id.
-- @return cid uint8_t[LOCAL_CONN_ID_LEN] buffer, hashId the same bytes as a Lua string
local function newCid()
	local cid = ffi.newNoAnchor("uint8_t[?]", LOCAL_CONN_ID_LEN)
	local hashId
	repeat
		scrypt.randomBytes(cid, LOCAL_CONN_ID_LEN)
		hashId = ffi.string(cid, LOCAL_CONN_ID_LEN)
		if connections[hashId] then
			util.printWarning("hashId %s exists", tostring(hashId))
		end
	until not connections[hashId]
	return cid, hashId
end

--- Accepts a new QUIC connection and registers it in `connections`.
-- @param odcid original destination connection id recovered from the retry token
-- @param odcid_len 1-element array holding the length of odcid
-- @return the new conn_io record (also stored in the module-level `conn_io`),
--         or nil when quiche_accept fails
local function create_conn(odcid, odcid_len)
	local record = newConnStruct()
	record.cid, record.hashId = newCid()
	local conn = quiche.quiche_accept(record.cid, LOCAL_CONN_ID_LEN, odcid, odcid_len[0], config)
	if ffi.isNull(conn) then
		-- nothing was registered, so a failed accept leaves no stale entry behind
		util.printWarning("failed to create connection")
		return nil
	end
	record.sock = conns.sock
	record.conn = conn
	record.timer.data = record
	-- HASH_ADD(hh, conns.h, cid, LOCAL_CONN_ID_LEN, conn_io)
	connections[record.hashId] = record
	conn_io = record
	util.print("new connection")
	return record
end

--- Logs one HTTP/3 header. Intended to be invoked through an FFI callback from
-- quiche_h3_event_for_each_header.
-- @param name header name bytes, read using name_len
-- @param name_len length of name
-- @param value header value bytes, read using value_len
-- @param value_len length of value
-- @param argp opaque user pointer passed through by quiche (unused)
-- @return 0: the C callback type returns int (a nil return cannot be converted),
--         and quiche stops iterating on any non-zero value
local function for_each_header(name, name_len, value, value_len, argp)
	util.print("got HTTP header: %s=%s", ffi.string(name, name_len), ffi.string(value, value_len))
	return 0
end

--- Points entry `i` of a quiche_h3_header array at the bytes of `tag` and `value`.
-- The Lua strings are not copied, so they must outlive the use of `headers`.
-- @param headers quiche_h3_header array (C indexing, starts at 0)
-- @param i index into headers
-- @param tag header name
-- @param value header value
local function setHeader(headers, i, tag, value)
	local header = headers[i]
	header.name, header.name_len = ffi.cast("const uint8_t *", tag), #tag
	header.value, header.value_len = ffi.cast("const uint8_t *", value), #value
end

-- C callback for quiche_h3_event_for_each_header. It is created lazily and then
-- reused, because every ffi.cast of a Lua function takes one of LuaJIT's limited
-- callback slots and never releases it on its own.
local headerCallback

--- Sends the first `written` bytes of the shared `out` buffer to a peer that has
-- no connection state yet (version negotiation and retry replies) and logs the result.
-- @param written number of bytes to send
-- @param addr peer address buffer
-- @param addrLen length of addr in bytes
local function sendStateless(written, addr, addrLen)
	local sent = conns.sock:sendto(out, written, sendFlags, addr, addrLen)
	if sent ~= written then
		util.print("failed to send")
		return
	end
	util.print("sent %d bytes", sent)
end

--- Answers the request on `streamId` of the current connection (`conn_io`)
-- with status 200 and the fixed 5-byte body "byez\n".
-- @param streamId stream id returned by quiche_h3_conn_poll
local function sendResponse(streamId)
	local headerCount = 3
	-- Per-response temporary. NOTE(review): assumes newNoAnchor is the GC-managed
	-- allocator used for the other temporaries in this file (newAnchor here leaked per request).
	local headers = ffi.newNoAnchor("quiche_h3_header[?]", headerCount)
	-- C arrays are 0-based; only ":status" is a response pseudo-header.
	setHeader(headers, 0, ":status", "200")
	setHeader(headers, 1, "server", "quiche")
	setHeader(headers, 2, "content-length", "5")
	local rc = quiche.quiche_h3_send_response(conn_io.http3, conn_io.conn, streamId, headers, headerCount, false)
	if rc ~= 0 then
		print("quiche_h3_send_response failed")
	end
	-- send_body returns the number of body bytes written, or a negative error code.
	local written = quiche.quiche_h3_send_body(conn_io.http3, conn_io.conn, streamId, "byez\n", 5, true)
	if written < 0 then
		print("quiche_h3_send_body failed")
	end
	util.print("sent HTTP response on stream %d", tonumber(streamId))
end

--- Handles every pending HTTP/3 event of the current connection (`conn_io`):
-- answers requests, drains request bodies, and frees each polled event.
local function pollHttp3()
	if headerCallback == nil then
		headerCallback = ffi.cast("int (*)(uint8_t *name, size_t name_len, uint8_t *value, size_t value_len, void *argp)", function(name, name_len, value, value_len, argp)
			for_each_header(name, name_len, value, value_len, argp)
			return 0 -- the C type returns int; non-zero would stop the iteration
		end)
	end
	-- quiche_h3_conn_poll stores a pointer to a newly allocated event in ev[0]
	local ev = ffi.newNoAnchor("quiche_h3_event *[1]")
	while true do
		local streamId = quiche.quiche_h3_conn_poll(conn_io.http3, conn_io.conn, ev)
		if streamId < 0 then
			break -- no more events (or an error)
		end
		local event = ev[0]
		local evType = quiche.quiche_h3_event_type(event)
		if evType == C.QUICHE_H3_EVENT_HEADERS then
			if quiche.quiche_h3_event_for_each_header(event, headerCallback, nil) ~= 0 then
				print("failed to process headers")
			end
			sendResponse(streamId)
		elseif evType == C.QUICHE_H3_EVENT_DATA then
			util.print("got HTTP data")
			-- The body is not used. Read it out so the stream is consumed. recvBuf is
			-- free to reuse because the datagram in it has already been processed.
			local len
			repeat
				len = tonumber(quiche.quiche_h3_recv_body(conn_io.http3, conn_io.conn, streamId, recvBuf, recvBufLen))
			until len <= 0
		elseif evType == C.QUICHE_H3_EVENT_FINISHED then
			util.print("got H3_EVENT_FINISHED ")
		end
		-- Every polled event must be freed, whatever its type.
		quiche.quiche_h3_event_free(event)
	end
end

--- Reads and processes every datagram currently queued on the listening socket,
-- then flushes outgoing packets of all connections and frees the closed ones.
-- @return nil normally, or "failed to read" when recvfrom fails with an error
--         other than EWOULDBLOCK/EAGAIN
local function recv_cb()
	local peer_addr = ffi.newNoAnchor("struct sockaddr_storage")
	local peer_addr_len = ffi.newNoAnchor("socklen_t[1]")
	local connType = ffi.newNoAnchor("uint8_t[1]")
	local version = ffi.newNoAnchor("uint32_t[1]")
	local scid = ffi.newNoAnchor("uint8_t[?]", C.QUICHE_MAX_CONN_ID_LEN)
	local scid_len = ffi.newNoAnchor("size_t[1]")
	local dcid = ffi.newNoAnchor("uint8_t[?]", C.QUICHE_MAX_CONN_ID_LEN)
	local dcid_len = ffi.newNoAnchor("size_t[1]")
	local odcid = ffi.newNoAnchor("uint8_t[?]", C.QUICHE_MAX_CONN_ID_LEN)
	local odcid_len = ffi.newNoAnchor("size_t[1]")
	local token = ffi.newNoAnchor("uint8_t[?]", MAX_TOKEN_LEN)
	local token_len = ffi.newNoAnchor("size_t[1]")

	while true do
		-- The length arrays are in/out parameters: capacity on input, actual length
		-- on output. Reset them for every datagram so a short previous value
		-- (e.g. token_len 0) does not shrink the next one's buffer.
		peer_addr_len[0] = ffi.sizeof(peer_addr)
		scid_len[0] = ffi.sizeof(scid)
		dcid_len[0] = ffi.sizeof(dcid)
		odcid_len[0] = ffi.sizeof(odcid)
		token_len[0] = ffi.sizeof(token)

		local readBytes = conns.sock:recvfrom(recvBuf, recvBufLen, receiveFlags, peer_addr, peer_addr_len)
		if readBytes < 0 then
			local errno = socket.lastError()
			if errno == C.EWOULDBLOCK or errno == C.EAGAIN then
				loopCount = loopCount + 1
				break -- socket drained
			end
			util.print("failed to read")
			return "failed to read"
		end

		local rc = quiche.quiche_header_info(recvBuf, readBytes, LOCAL_CONN_ID_LEN, version, connType, scid, scid_len, dcid, dcid_len, token, token_len)
		if rc < 0 then
			util.print("failed to parse header: %d\n", rc)
			goto continue
		end

		-- HASH_FIND(hh, conns->h, dcid, dcid_len, conn_io)
		local hashId = ffi.string(dcid, dcid_len[0])
		conn_io = connections[hashId]
		if conn_io == nil then
			if not quiche.quiche_version_is_supported(version[0]) then
				util.print("version negotiation")
				local written = tonumber(quiche.quiche_negotiate_version(scid, scid_len[0], dcid, dcid_len[0], out, ffi.sizeof(out)))
				if written < 0 then
					util.print("failed to create vneg packet: %d", written)
				else
					sendStateless(written, peer_addr, peer_addr_len[0])
				end
				goto continue
			end
			if token_len[0] == 0 then
				util.print("stateless retry")
				mint_token(dcid, dcid_len, peer_addr, peer_addr_len, token, token_len)
				local written = tonumber(quiche.quiche_retry(scid, scid_len[0], dcid, dcid_len[0], dcid, dcid_len[0], token, token_len[0], version[0], out, ffi.sizeof(out)))
				if written < 0 then
					util.print("failed to create retry packet: %d", written)
				else
					sendStateless(written, peer_addr, peer_addr_len[0])
				end
				goto continue
			end
			if not validate_token(token, token_len, peer_addr, peer_addr_len, odcid, odcid_len) then
				util.print("invalid address validation token")
				goto continue
			end
			conn_io = create_conn(odcid, odcid_len)
			if conn_io == nil then
				goto continue
			end
			ffi.copy(conn_io.peer_addr, peer_addr, peer_addr_len[0])
			conn_io.peer_addr_len[0] = peer_addr_len[0]
		end

		local done = tonumber(quiche.quiche_conn_recv(conn_io.conn, recvBuf, readBytes))
		if done < 0 then
			util.print("failed to process packet: %d", done)
			goto continue
		end
		util.print("recv %d bytes", done)

		if quiche.quiche_conn_is_established(conn_io.conn) then
			if conn_io.http3 == nil then
				conn_io.http3 = quiche.quiche_h3_conn_new_with_transport(conn_io.conn, http3_config)
				if ffi.isNull(conn_io.http3) then
					util.print("failed to create HTTP/3 connection")
					goto continue
				end
			end
			pollHttp3()
		end
		::continue::
	end

	-- HASH_ITER(hh, conns.h, conn_io, tmp)
	local closed = {}
	local count = 0
	for _, rec in pairs(connections) do
		count = count + 1
		conn_io = rec
		flush_egress()
		if quiche.quiche_conn_is_closed(conn_io.conn) then
			closed[#closed + 1] = conn_io
		end
	end
	-- Free closed connections after the pairs() traversal has finished.
	for _, rec in ipairs(closed) do
		count = count - 1
		conn_io = rec
		closeConnection()
	end
	util.print("connections: %d", count)
end

--- Starts the HTTP/3 server and serves until recv_cb reports a socket read error.
-- @param arg optional array {host, port, certFile, keyFile}. Defaults are
--   "127.0.0.1", 8443, "~/nc/nc-server/lib/net/cert.crt" and
--   "~/nc/nc-server/lib/net/cert.key"; both paths pass through fs.filePathFix.
local function main(arg)
	local host = arg and arg[1] or "127.0.0.1"
	local port = tostring(arg and arg[2] or 8443)
	local certFile = arg and arg[3] or "~/nc/nc-server/lib/net/cert.crt"
	local keyFile = arg and arg[4] or "~/nc/nc-server/lib/net/cert.key"
	util.print("listening %s:%s", host, port)
	local sock = socket.listen(host, port, "http3")
	if sock == nil then
		return
	end

	-- Releases whichever quiche configs exist. Called on every exit path, so early
	-- returns do not leak them and a NULL config is never handed to *_free.
	local function freeConfigs()
		if http3_config ~= nil and not ffi.isNull(http3_config) then
			quiche.quiche_h3_config_free(http3_config)
		end
		http3_config = nil
		if config ~= nil and not ffi.isNull(config) then
			quiche.quiche_config_free(config)
		end
		config = nil
	end

	quiche.quiche_enable_debug_logging(ffi.cast("void (*)(const char *line, void *argp)", debugLog), nil)
	config = quiche.quiche_config_new(C.QUICHE_PROTOCOL_VERSION)
	if ffi.isNull(config) then
		util.printError("failed to create quiche config")
		return
	end

	local err = quiche.quiche_config_load_cert_chain_from_pem_file(config, fs.filePathFix(certFile))
	if err ~= 0 then
		util.printError("quiche_config_load_cert_chain_from_pem_file error %d", err)
		freeConfigs()
		return
	end
	err = quiche.quiche_config_load_priv_key_from_pem_file(config, fs.filePathFix(keyFile))
	if err ~= 0 then
		util.printError("quiche_config_load_priv_key_from_pem_file error %d", err)
		freeConfigs()
		return
	end

	print("QUICHE_H3_APPLICATION_PROTOCOL: " .. QUICHE_H3_APPLICATION_PROTOCOL)
	print(#QUICHE_H3_APPLICATION_PROTOCOL)
	err = quiche.quiche_config_set_application_protos(config, ffi.cast("uint8_t *", QUICHE_H3_APPLICATION_PROTOCOL), #QUICHE_H3_APPLICATION_PROTOCOL)
	if err ~= 0 then
		util.printError("quiche_config_set_application_protos error %d", err)
		freeConfigs()
		return
	end
	quiche.quiche_config_set_max_idle_timeout(config, 5000)
	quiche.quiche_config_set_max_udp_payload_size(config, MAX_DATAGRAM_SIZE)
	quiche.quiche_config_set_initial_max_data(config, 10000000)
	quiche.quiche_config_set_initial_max_stream_data_bidi_local(config, 1000000)
	quiche.quiche_config_set_initial_max_stream_data_bidi_remote(config, 1000000)
	quiche.quiche_config_set_initial_max_stream_data_uni(config, 1000000)
	quiche.quiche_config_set_initial_max_streams_bidi(config, 100)
	quiche.quiche_config_set_initial_max_streams_uni(config, 100)
	quiche.quiche_config_set_disable_active_migration(config, true)
	quiche.quiche_config_set_cc_algorithm(config, C.QUICHE_CC_RENO)

	http3_config = quiche.quiche_h3_config_new()
	if ffi.isNull(http3_config) then
		util.printError("failed to create http3 config")
	else
		util.print("HTTP/3 %s:%s\n", host, port)
		conns.sock = sock
		local ret
		repeat
			ret = recv_cb()
			-- NOTE: timeout_cb() is not driven here yet, so quiche_conn_on_timeout is never called.
			util.sleep(sleepMillisec)
		until ret
	end
	freeConfigs()
end

-- Script entry: report the linked quiche version, then serve using the command-line arguments.
local quicheVersion = ffi.string(quiche.quiche_version())
util.printInfo("* quiche version: %s", quicheVersion)
main({...})
