--- lib/net/websocket.lua
-- Only RFC 6455 is supported.
-- see: permessage-deflate.lua, https://github.com/moteus/lua-websockets-permessage-deflate
--
local websocket = {}

local bit = require "bit"
local util = require "util"
local http = require "http"
local crypto = require "system/crypto"
local base64encode = require"base64".encode

local newTable = util.newTable
-- local wscompress = require "websocket-permessage-deflate" -- not in use yet
-- local zstd = require "compress/zstd" -- not in use yet

--[[
local function wsEcho(data)
	return data
end

local function wsChat(data)
	return data
end

local function wsNcJson(data)
	return data
end

local protocolHandler = { -- not in use now
  echo = wsEcho,     -- echos payload to ALL clients
  chat = wsChat,     -- sends "[name] message..." to all other 'chat' clients
  -- "command" = 'doCommand', 	-- accepts commands to the server
	["nc-json"] =  wsNcJson,
} ]]

-- Maps a frame opcode (low nibble of the first frame byte) to the type name
-- that receiveFrame reports to its caller.
local types = {
	[0x0] = "continuation",
	[0x1] = "text",
	[0x2] = "binary",
	[0x8] = "close",
	[0x9] = "ping",
	[0xa] = "pong",
}

--- Validates a WebSocket opening handshake and builds the response headers.
-- Request headers are read through http.getHeader(data, name).
-- The 'Connection: Upgrade' request header is not checked here; per the
-- original note it is already verified by the caller in rest.lua.
-- No extension (e.g. permessage-deflate) is negotiated, so no
-- Sec-WebSocket-Extensions header is emitted and clients must not set RSV1.
-- @param data request object/string understood by http.getHeader
-- @return on success: array of "Name: value" header lines, followed by the
--         status string "101 Switching Protocols"
-- @return on failure: nil, followed by an error message
function websocket.upgrade(data)
	local upgrade = http.getHeader(data, "Upgrade")
	-- RFC 6455 section 4.2.1: the Upgrade token is ASCII case-insensitive.
	if type(upgrade) ~= "string" or upgrade:lower() ~= "websocket" then
		return nil, "bad 'Upgrade' request header"
	end
	local key = http.getHeader(data, "Sec-WebSocket-Key")
	if not key then
		return nil, "bad 'Sec-WebSocket-Key' request header"
	end
	local ver = http.getHeader(data, "Sec-WebSocket-Version")
	if ver ~= "13" then
		return nil, "bad 'Sec-WebSocket-Version' request header"
	end

	local header = {
		"Connection: Upgrade",
		"Upgrade: websocket",
	}

	local protocols = http.getHeader(data, "Sec-WebSocket-Protocol")
	if type(protocols) == "string" then
		-- The client may offer a comma separated list, but the response must
		-- name exactly one selected subprotocol (RFC 6455 section 4.2.2);
		-- echoing the whole list makes browsers fail the connection.
		-- No protocol handler table is in use yet, so take the first offer.
		local selected = protocols:match("^%s*([^,%s]+)")
		if selected then
			header[#header + 1] = "Sec-WebSocket-Protocol: " .. selected
		end
	end

	-- Accept token: base64 of SHA-1 over key .. fixed GUID from RFC 6455.
	-- NOTE(review): assumes crypto.sha1 returns the raw 20-byte digest rather
	-- than a hex string -- confirm against system/crypto.
	local sha1 = crypto.sha1(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
	header[#header + 1] = "Sec-WebSocket-Accept: " .. base64encode(sha1)

	return header, "101 Switching Protocols" -- second return value is status
end

-- Read cursor into sock.data: number of bytes already handed out by sockData.
-- It is module-level state, reset by websocket.receiveFrame on every call.
local sockDataPos = 0

--- Returns the next `count` bytes of sock.data and advances the read cursor.
-- string.sub never returns nil, it just yields a shorter string when the
-- buffer ends early, so the available length is checked explicitly; callers
-- rely on a nil result to detect truncated frames.
-- @param sock table whose `data` field holds the received bytes as a string
-- @param count number of bytes to read
-- @return the requested bytes, or nil plus an error message when sock.data is
--         not a string or holds fewer than `count` unread bytes (the cursor
--         is left untouched in that case)
local function sockData(sock, count)
	local buffer = sock.data
	if type(buffer) ~= "string" then
		return nil, "no data"
	end
	local last = sockDataPos + count
	if last > #buffer then
		return nil, "not enough data"
	end
	local first = sockDataPos + 1
	sockDataPos = last
	return buffer:sub(first, last)
end

--- Reads exactly `count` bytes of the current frame through sockData.
-- A result that is nil or shorter than requested means sock.data ended before
-- the frame did; it is reported as an error so the parser never feeds missing
-- bytes (nil) into string.byte / bit operations.
-- @return the bytes, or nil plus an error message
local function readExact(sock, count)
	local chunk, err = sockData(sock, count)
	if not chunk then
		return nil, err
	end
	if #chunk < count then
		return nil, "not enough data"
	end
	return chunk
end

--- Removes the masking from a masked frame body.
-- @param data 4-byte masking key immediately followed by the masked payload
-- @param payload_len number of payload bytes following the key
-- @return the unmasked payload string
local function unmaskPayload(data, payload_len)
	local key = {string.byte(data, 1, 4)}
	-- TODO string.buffer optimizations
	local bytes = newTable(payload_len, 0)
	for i = 1, payload_len do
		-- RFC 6455 section 5.3: payload octet i is XORed with key octet i MOD 4.
		bytes[i] = string.char(bit.bxor(string.byte(data, 4 + i), key[(i - 1) % 4 + 1]))
	end
	return table.concat(bytes)
end

--- Parses one WebSocket frame from the start of sock.data.
-- The read cursor is reset on every call, so only the first frame contained
-- in sock.data is parsed.
-- @param sock table whose `data` field holds the raw received bytes
-- @param max_payload_len largest accepted payload length in bytes
--        (defaults to 65535 when nil)
-- @param force_masking when truthy, frames without the mask bit are rejected
-- @return on error: nil, nil, error message
-- @return for a close frame: message (possibly ""), "close", status code
--         (nil when the close frame has no body)
-- @return otherwise: payload, type name ("continuation", "text", "binary",
--         "ping" or "pong"), and "again" when FIN is not set (else nil)
function websocket.receiveFrame(sock, max_payload_len, force_masking)
	max_payload_len = max_payload_len or 65535
	sockDataPos = 0

	local data, err = readExact(sock, 2)
	if not data then
		return nil, nil, "failed to receive the first 2 bytes: " .. (err or "unknown")
	end
	local fst, snd = string.byte(data, 1, 2)
	local fin = bit.band(fst, 0x80) ~= 0
	-- No extension is negotiated, so all three RSV bits must be zero.
	if bit.band(fst, 0x70) ~= 0 then
		return nil, nil, "bad RSV1, RSV2, or RSV3 bits"
	end
	local opcode = bit.band(fst, 0x0f)
	if opcode >= 0x3 and opcode <= 0x7 then
		return nil, nil, "reserved non-control frames"
	end
	if opcode >= 0xb and opcode <= 0xf then
		return nil, nil, "reserved control frames"
	end
	local mask = bit.band(snd, 0x80) ~= 0
	if force_masking and not mask then
		return nil, nil, "frame unmasked"
	end

	local payload_len = bit.band(snd, 0x7f)
	if payload_len == 126 then
		-- 16-bit big-endian extended length
		data, err = readExact(sock, 2)
		if not data then
			return nil, nil, "failed to receive the 2 byte payload length: " .. (err or "unknown")
		end
		payload_len = bit.bor(bit.lshift(string.byte(data, 1), 8), string.byte(data, 2))
	elseif payload_len == 127 then
		-- 64-bit big-endian extended length; only 31-bit values are accepted
		data, err = readExact(sock, 8)
		if not data then
			return nil, nil, "failed to receive the 8 byte payload length: " .. (err or "unknown")
		end
		if string.byte(data, 1) ~= 0 or string.byte(data, 2) ~= 0 or string.byte(data, 3) ~= 0 or string.byte(data, 4) ~= 0 then
			return nil, nil, "payload len too large"
		end
		local fifth = string.byte(data, 5)
		if bit.band(fifth, 0x80) ~= 0 then
			return nil, nil, "payload len too large"
		end
		payload_len = bit.bor(bit.lshift(fifth, 24), bit.lshift(string.byte(data, 6), 16), bit.lshift(string.byte(data, 7), 8), string.byte(data, 8))
	end

	if bit.band(opcode, 0x8) ~= 0 then
		-- being a control frame
		if payload_len > 125 then
			return nil, nil, "too long payload for control frame"
		end
		if not fin then
			return nil, nil, "fragmented control frame"
		end
	end
	if payload_len > max_payload_len then
		return nil, nil, "exceeding max payload len"
	end

	-- A masked frame carries a 4-byte masking key in front of the payload.
	local rest = mask and payload_len + 4 or payload_len
	if rest > 0 then
		data, err = readExact(sock, rest)
		if not data then
			return nil, nil, "failed to read masking-len and payload: " .. (err or "unknown")
		end
	else
		data = ""
	end

	local payload = data
	if mask then
		payload = unmaskPayload(data, payload_len)
	end

	if opcode == 0x8 then
		-- being a close frame
		if payload_len == 0 then
			return "", "close", nil
		end
		if payload_len < 2 then
			return nil, nil, "close frame with a body must carry a 2-byte status code"
		end
		-- The first two payload bytes are the big-endian status code, the
		-- remainder (possibly empty) is the close reason.
		local code = bit.bor(bit.lshift(string.byte(payload, 1), 8), string.byte(payload, 2))
		return string.sub(payload, 3), "close", code
	end

	return payload, types[opcode], not fin and "again" or nil
end

--- Serializes a single WebSocket frame into a string.
-- @param fin truthy to set the FIN bit in the first byte
-- @param opcode frame opcode (numeric)
-- @param payload_len length of `payload` in bytes
-- @param payload payload string
-- @param masking truthy to set the mask bit and mask the payload with a
--        random 4-byte key
-- @return the encoded frame, or nil plus "payload too big" when the length
--         does not fit into 31 bits
local function buildFrame(fin, opcode, payload_len, payload, masking)
	-- XXX optimize this when we have string.buffer in LuaJIT 2.1
	local first_byte = opcode
	if fin then
		first_byte = bit.bor(0x80, opcode)
	end

	-- Second byte holds the 7-bit length (or the 126/127 escape values); the
	-- real length then follows as 2 or 8 big-endian bytes.
	local second_byte
	local length_ext = ""
	if payload_len <= 125 then
		second_byte = payload_len
	elseif payload_len <= 65535 then
		second_byte = 126
		local hi = bit.band(bit.rshift(payload_len, 8), 0xff)
		local lo = bit.band(payload_len, 0xff)
		length_ext = string.char(hi, lo)
	else
		-- XXX we only support 31-bit length here
		if bit.band(payload_len, 0x7fffffff) < payload_len then
			return nil, "payload too big"
		end
		second_byte = 127
		length_ext = string.char(
			0, 0, 0, 0,
			bit.band(bit.rshift(payload_len, 24), 0xff),
			bit.band(bit.rshift(payload_len, 16), 0xff),
			bit.band(bit.rshift(payload_len, 8), 0xff),
			bit.band(payload_len, 0xff)
		)
	end

	if not masking then
		-- Unmasked frame: header, optional extended length, raw payload.
		return string.char(first_byte, second_byte) .. length_ext .. payload
	end

	-- set the mask bit
	second_byte = bit.bor(second_byte, 0x80)
	local seed = math.random(0xffffffff) -- todo: use crypto random
	local key_bytes = {
		bit.band(bit.rshift(seed, 24), 0xff),
		bit.band(bit.rshift(seed, 16), 0xff),
		bit.band(bit.rshift(seed, 8), 0xff),
		bit.band(seed, 0xff),
	}

	-- TODO string.buffer optimizations
	local masked = newTable(payload_len, 0)
	for pos = 1, payload_len do
		local key_byte = key_bytes[util.fmod(pos - 1, 4) + 1]
		masked[pos] = string.char(bit.bxor(string.byte(payload, pos), key_byte))
	end

	local masking_key = string.char(key_bytes[1], key_bytes[2], key_bytes[3], key_bytes[4])
	return string.char(first_byte, second_byte) .. length_ext .. masking_key .. table.concat(masked)
end

--- Validates the arguments and encodes one WebSocket frame.
-- The frame is only built and returned; nothing is written to `sock`, which
-- is currently unused and kept for interface compatibility.
-- @param sock unused
-- @param fin truthy when this is the final fragment of a message
-- @param opcode frame opcode (numeric); opcodes with bit 0x8 set are control
--        frames
-- @param payload payload; nil/false becomes "", non-strings go through
--        tostring
-- @param max_payload_len largest allowed payload length in bytes
--        (defaults to 65535 when nil)
-- @param masking truthy to mask the payload
-- @return the encoded frame string, or nil plus an error message
function websocket.createFrame(sock, fin, opcode, payload, max_payload_len, masking)
	if not payload then
		payload = ""
	elseif type(payload) ~= "string" then
		payload = tostring(payload)
	end
	-- An omitted limit used to raise "attempt to compare number with nil".
	max_payload_len = max_payload_len or 65535
	local payload_len = #payload
	if payload_len > max_payload_len then
		return nil, "payload too big"
	end
	if bit.band(opcode, 0x8) ~= 0 then
		-- being a control frame
		if payload_len > 125 then
			return nil, "too much payload for control frame"
		end
		if not fin then
			return nil, "fragmented control frame"
		end
	end
	local frame, err = buildFrame(fin, opcode, payload_len, payload, masking)
	if not frame then
		return nil, "failed to build frame: " .. (err or "unknown")
	end
	return frame
end

-- Export the module table (upgrade, receiveFrame, createFrame).
return websocket
