From 9c18bec53f8542c8fd7ed2d526eda729bad9298e Mon Sep 17 00:00:00 2001 From: cpunt Date: Mon, 18 May 2026 16:31:48 +0000 Subject: [PATCH 1/2] lua: implement rtl8380m wired provider --- src/services/hal.lua | 7 +- .../wired/providers/rtl8380m_http.lua | 593 ++++++++++++++++-- src/services/hal/managers/wired.lua | 46 +- tests/unit/hal/wired_provider_spec.lua | 57 +- 4 files changed, 641 insertions(+), 62 deletions(-) diff --git a/src/services/hal.lua b/src/services/hal.lua index 0d760d6c..86193d04 100644 --- a/src/services/hal.lua +++ b/src/services/hal.lua @@ -722,7 +722,7 @@ function HalService.start(conn, opts) return out end - function registry:terminate_caps(reason) + function registry:terminate_caps(_reason) -- reason is accepted for finaliser-shaped call sites; bus endpoints -- expose immediate unbind without a reason parameter. for _, class_caps in pairs(self.caps) do @@ -1090,7 +1090,8 @@ function HalService.start(conn, opts) manager_start_timeout_s, manager_logger, dev_ev_ch, - cap_emit_ch + cap_emit_ch, + conn )) end @@ -1255,7 +1256,7 @@ function HalService.start(conn, opts) svc:obs_log('info', { what = 'subscribed', topic = 'cfg/' .. svc.name }) while true do - local source, a, b = perform(op.named_choice({ + local source, a = perform(op.named_choice({ rpc = op.choice(registry:rpc_ops()), manager_fault = op.choice(manager_fault_ops()), cap_emit = cap_emit_ch:get_op(), diff --git a/src/services/hal/backends/wired/providers/rtl8380m_http.lua b/src/services/hal/backends/wired/providers/rtl8380m_http.lua index d6c2c928..e9b4b6a5 100644 --- a/src/services/hal/backends/wired/providers/rtl8380m_http.lua +++ b/src/services/hal/backends/wired/providers/rtl8380m_http.lua @@ -1,42 +1,19 @@ -- services/hal/backends/wired/providers/rtl8380m_http.lua -- --- Phase 1 stub for the pre-devicecode RTL8380M switch-fabric driver. --- --- This provider is deliberately telemetry-only. It defines the API the real --- manufacturer-firmware HTTP driver must implement, without baking HTTP details --- into Wired, Device or Net. --- --- Expected real-driver implementation: --- * fetch_snapshot_op(req) must perform the HTTP request(s) and return: --- { --- ok = true, --- provider_id = "switch-main", --- mode = "read_only", -- Phase 1; "writable" in Phase 2 --- writable = false, --- status = { state="available", available=true, ... }, --- surfaces = { --- ["port-1"] = { --- provider_surface_id = "port-1", --- kind = "ethernet-port", --- link = { state="up", speed_mbps=1000, duplex="full" }, --- attachment = { mode="access", vlan=100 }, --- poe = { state="off"|"delivering"|"fault", watts=0 }, --- }, --- ["uplink-cm5"] = { --- provider_surface_id = "uplink-cm5", --- kind = "switch-port", --- link = { state="up", speed_mbps=1000 }, --- attachment = { mode="trunk", vlans={10,11,12,100,101} }, --- }, --- }, --- topology = { ... provider-observed topology, semantic not HTTP-shaped ... }, --- } --- * Phase 1 control methods must return read_only. --- * Phase 2 should implement apply_attachments_op/set_poe_op/bounce_op with --- the same semantic request/response shape; no caller above HAL should know --- manufacturer URL paths, session cookies, page forms or register names. +-- Read-only RTL8380M switch provider backed by the firmware HTTP CGI surface. +-- Manufacturer paths, login forms and password encryption stay inside HAL; the +-- Device and Wired services only see semantic wired-provider snapshots. +local cjson = require 'cjson.safe' +local fibers = require 'fibers' local op = require 'fibers.op' +local sleep = require 'fibers.sleep' +local file = require 'fibers.io.file' +local exec = require 'fibers.io.exec' + +local blob_source = require 'devicecode.blob_source' +local resource = require 'devicecode.support.resource' +local http_sdk = require 'services.http.sdk' local contract = require 'services.hal.backends.wired.contract' local tablex = require 'shared.table' @@ -44,7 +21,12 @@ local M = {} local Provider = {} Provider.__index = Provider +local EXPONENT_HEX = '10001' +local DEFAULT_TIMEOUT_S = 10 +local DEFAULT_HTTP_CAP = 'main' + local function copy(v) return tablex.deep_copy(v) end +local function table_or_empty(v) return type(v) == 'table' and v or {} end local function default_surfaces() return { @@ -58,22 +40,515 @@ local function default_surfaces() } end +local function url_escape_form(s) + return (tostring(s or ''):gsub('([^%w%-%._~])', function (c) + return ('%%%02X'):format(string.byte(c)) + end)) +end + +local function urlencode_b64(s) + return (tostring(s or ''):gsub('[+/=]', function(c) + return ('%%%02X'):format(string.byte(c)) + end)) +end + +local B64 = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/' + +local function base64_encode(s) + s = tostring(s or '') + local out = {} + for i = 1, #s, 3 do + local b1 = s:byte(i) or 0 + local b2 = s:byte(i + 1) or 0 + local b3 = s:byte(i + 2) or 0 + local n = b1 * 65536 + b2 * 256 + b3 + local c1 = math.floor(n / 262144) % 64 + local c2 = math.floor(n / 4096) % 64 + local c3 = math.floor(n / 64) % 64 + local c4 = n % 64 + out[#out + 1] = B64:sub(c1 + 1, c1 + 1) + out[#out + 1] = B64:sub(c2 + 1, c2 + 1) + out[#out + 1] = (i + 1 <= #s) and B64:sub(c3 + 1, c3 + 1) or '=' + out[#out + 1] = (i + 2 <= #s) and B64:sub(c4 + 1, c4 + 1) or '=' + end + return table.concat(out) +end + +local function resolve_env_value(v) + if type(v) ~= 'string' then return v end + local name = v:match('^%$ENV:([%w_]+)$') + or v:match('^env:([%w_]+)$') + or v:match('^%$([%w_]+)$') + if name then return os.getenv(name) end + return v +end + +local function normalise_auth(config) + local auth = table_or_empty(config.auth) + local username_ref = auth.username_env and ('$' .. tostring(auth.username_env)) or nil + local password_ref = auth.password_env and ('$' .. tostring(auth.password_env)) or nil + return { + username = resolve_env_value(auth.username or config.username or username_ref), + password = resolve_env_value(auth.password or config.password or password_ref), + enabled = config.login ~= false and auth.enabled ~= false, + } +end + +local function parse_base_url(url) + if type(url) ~= 'string' or url == '' then return nil end + local scheme, rest = url:match('^(https?)://(.+)$') + if not scheme then return nil end + local authority, path = rest:match('^([^/]+)(/.*)$') + if not authority then authority, path = rest, '' end + local host, port = authority:match('^%[([^%]]+)%]:(%d+)$') + if not host then host, port = authority:match('^([^:]+):(%d+)$') end + if not host then host = authority end + return { + scheme = scheme, + host = host, + port = tonumber(port), + prefix = path and path:gsub('/+$', '') or '', + } +end + +local function normalise_http_config(config) + local http = table_or_empty(config.http) + local base_url = config.base_url or config.url or http.base_url or http.url + local parsed = parse_base_url(base_url) + local scheme = http.scheme or config.scheme or (parsed and parsed.scheme) or 'http' + local host = http.host or config.host or (parsed and parsed.host) + local port = tonumber(http.port or config.port or (parsed and parsed.port)) + if not port then port = (scheme == 'https') and 443 or 80 end + return { + cap_id = http.cap_id or config.http_cap_id or DEFAULT_HTTP_CAP, + scheme = scheme, + host = host, + port = port, + prefix = http.prefix or config.path_prefix or (parsed and parsed.prefix) or '', + timeout_s = tonumber(http.timeout_s or config.timeout_s) or DEFAULT_TIMEOUT_S, + headers = copy(http.headers or config.headers or {}), + } +end + +local function append_dummy(path) + local sep = path:find('?', 1, true) and '&' or '?' + return path .. sep .. 'dummy=' .. tostring(math.floor(os.time() * 1000)) +end + +local function status_ok(status) + local n = tonumber(status) + return n ~= nil and n >= 200 and n < 300 +end + +local function decode_json(body) + local js, err = cjson.decode(body or '') + if js == nil then return nil, 'decode error: ' .. tostring(err) end + return js, nil +end + +local function path_join(prefix, path) + prefix = tostring(prefix or ''):gsub('/+$', '') + if prefix == '' then return path end + if path:sub(1, 1) == '/' then return prefix .. path end + return prefix .. '/' .. path +end + +local function is_up_value(v) + if v == true or v == 1 then return true end + local s = tostring(v or ''):lower() + return s == 'up' or s == 'linkup' or s == 'link-up' or s == 'connected' or s == 'on' +end + +local function is_down_value(v) + if v == false or v == 0 then return true end + local s = tostring(v or ''):lower() + return s == 'down' or s == 'linkdown' or s == 'link-down' or s == 'disconnected' or s == 'off' +end + +local function link_state(port) + local v = port.link or port.link_state or port.linkStatus or port.status or port.state or port.up + if is_up_value(v) then return 'up' end + if is_down_value(v) then return 'down' end + return 'unknown' +end + +local function speed_mbps(port) + local v = port.speed_mbps or port.speed or port.linkSpeed or port.rate + if type(v) == 'number' then return v end + local s = tostring(v or '') + local n = tonumber(s:match('(%d+)%s*[Gg]')) + if n then return n * 1000 end + n = tonumber(s:match('(%d+)%s*[Mm]')) + if n then return n end + return tonumber(s) +end + +local function port_key(port, idx) + local raw = port.id or port.port_id or port.port or port.name or port.ifname or idx + local token = tostring(raw or idx):lower():gsub('^port[%s_%-]*', '') + token = token:gsub('[^%w%._%-]', '-') + if token == '' then token = tostring(idx) end + if token:match('^port%-') then return token end + return 'port-' .. token +end + +local function array_items(t) + if type(t) ~= 'table' then return function () return nil end end + if #t > 0 then + local i = 0 + return function () + i = i + 1 + if i <= #t then return i, t[i] end + return nil + end + end + local keys = tablex.sorted_keys(t) + local i = 0 + return function () + i = i + 1 + local k = keys[i] + if k ~= nil then return k, t[k] end + return nil + end +end + +local function poe_by_port(ports_poe) + local out = {} + for k, rec in array_items(ports_poe) do + if type(rec) == 'table' then + out[port_key(rec, k)] = rec + end + end + return out +end + +local function normalise_poe(rec) + if type(rec) ~= 'table' then return nil end + local state = rec.state or rec.status or rec.poeStatus or rec.enable + local watts = tonumber(rec.watts or rec.power or rec.consumption or rec.outputPower) or 0 + local s = tostring(state or ''):lower() + if s == '' and watts > 0 then s = 'delivering' end + if s == 'on' or s == 'enabled' then s = watts > 0 and 'delivering' or 'on' end + if s == 'off' or s == 'disabled' then s = 'off' end + if s:find('fault', 1, true) or s:find('error', 1, true) then s = 'fault' end + return { + state = s ~= '' and s or 'unknown', + watts = watts, + limit_watts = tonumber(rec.limit_watts or rec.limit or rec.maxPower), + } +end + +local function merge_surface(base, observed) + local out = copy(base or {}) + for k, v in pairs(observed or {}) do out[k] = copy(v) end + return out +end + +local function build_surfaces(configured, stats) + local surfaces = copy(configured or default_surfaces()) + local ports = table_or_empty(stats and stats.ports) + local poe = poe_by_port(stats and stats.ports_poe) + local observed_count = 0 + + for k, port in array_items(ports) do + if type(port) == 'table' then + local id = port_key(port, k) + observed_count = observed_count + 1 + local poe_rec = poe[id] + local observed = { + provider_surface_id = id, + kind = 'ethernet-port', + capabilities = { access = true, trunk = true, poe = poe_rec ~= nil }, + link = { + state = link_state(port), + speed_mbps = speed_mbps(port), + duplex = port.duplex or port.linkDuplex, + }, + attachment = { + mode = port.mode or port.vlan_mode or 'unknown', + vlan = port.vlan, + vlans = port.vlans or port.tagged, + }, + poe = normalise_poe(poe_rec), + raw = copy(port), + } + surfaces[id] = merge_surface(surfaces[id], observed) + end + end + + return surfaces, observed_count +end + +local function make_unavailable(self, code, err) + return { + ok = false, + provider_id = self.id, + mode = self.mode, + writable = false, + code = code, + err = err, + status = { + state = 'unavailable', + available = false, + mode = self.mode, + driver = 'rtl8380m_http', + code = code, + err = err, + base_url_configured = self.http.host ~= nil, + }, + surfaces = copy(self.surfaces), + topology = copy(self.topology), + telemetry = copy(self.telemetry), + } +end + +local function write_tmp(scope, label, data, tmpdir) + local stream, err = file.tmpfile('rw-------', tmpdir) + if not stream then return nil, label .. '_tmpfile_failed:' .. tostring(err) end + scope:finally(function (_, status, primary) + resource.terminate_checked( + stream, + primary or status or label .. '_tmpfile_closed', + label .. '_tmpfile_cleanup_failed' + ) + end) + local path = stream:filename() + if type(path) ~= 'string' or path == '' then return nil, label .. '_tmpfile_path_unavailable' end + local n, werr = fibers.perform(stream:write_op(data)) + if n == nil then return nil, label .. '_write_failed:' .. tostring(werr) end + local ok, ferr = fibers.perform(stream:flush_op()) + if ok == nil then return nil, label .. '_flush_failed:' .. tostring(ferr) end + return path, nil +end + +local function run_checked(cmd, label) + local out, status, code, signal, err = fibers.perform(cmd:combined_output_op()) + if status == 'exited' and code == 0 then return out or '', nil end + if status == 'signalled' then return nil, label .. '_signalled:' .. tostring(signal) end + return nil, label .. '_failed:' .. tostring(err or out or code or status) +end + +local function encrypt_password_op(self, modulus_hex, password) + return fibers.run_scope_op(function (scope) + local asn1 = ([[ +asn1=SEQUENCE:pubkey +[pubkey] +modulus=INTEGER:0x%s +pubexp=INTEGER:0x%s +]]):format(tostring(modulus_hex or ''), EXPONENT_HEX) + + local asn1_path, aerr = write_tmp(scope, 'rtl8380m_asn1', asn1, self.tmpdir) + if not asn1_path then return nil, aerr end + local der_path, derr = write_tmp(scope, 'rtl8380m_der', '', self.tmpdir) + if not der_path then return nil, derr end + local pem_path, perr = write_tmp(scope, 'rtl8380m_pem', '', self.tmpdir) + if not pem_path then return nil, perr end + + local _, err = run_checked(self.exec.command( + 'openssl', 'asn1parse', '-genconf', asn1_path, '-out', der_path, '-noout' + ), 'openssl_asn1parse') + if err then return nil, err end + + _, err = run_checked(self.exec.command( + 'openssl', 'rsa', '-RSAPublicKey_in', '-inform', 'DER', '-in', der_path, '-out', pem_path, '-pubout' + ), 'openssl_rsa') + if err then return nil, err end + + local cmd = self.exec.command { + 'openssl', 'pkeyutl', '-encrypt', '-inkey', pem_path, '-pubin', + '-pkeyopt', 'rsa_padding_mode:pkcs1', + stdin = 'pipe', + stdout = 'pipe', + stderr = 'null', + } + local stdin, serr = cmd:stdin_stream() + if not stdin then return nil, 'openssl_pkeyutl_stdin_failed:' .. tostring(serr) end + local n, werr = fibers.perform(stdin:write_op(tostring(password or ''))) + if n == nil then return nil, 'openssl_pkeyutl_password_write_failed:' .. tostring(werr) end + resource.terminate_checked(stdin, 'password_written', 'openssl_pkeyutl_stdin_cleanup_failed') + + local ciphertext, status, code, signal, cerr = fibers.perform(cmd:output_op()) + if status ~= 'exited' or code ~= 0 then + if status == 'signalled' then return nil, 'openssl_pkeyutl_signalled:' .. tostring(signal) end + return nil, 'openssl_pkeyutl_failed:' .. tostring(cerr or code or status) + end + if type(ciphertext) ~= 'string' or ciphertext == '' then return nil, 'openssl_pkeyutl_empty_output' end + return urlencode_b64(base64_encode(ciphertext)), nil + end):wrap(function (status, report, encoded, err) + if status ~= 'ok' then return nil, err or (report and report.primary) or status end + return encoded, err + end) +end + function M.new(config, opts) config = config or {} + opts = opts or {} + local http = normalise_http_config(config) + local auth = normalise_auth(config) return setmetatable({ id = config.id or config.capability_id or 'switch-main', - base_url = config.base_url or config.url, mode = config.mode or 'read_only', + http = http, + auth = auth, telemetry = copy(config.telemetry or {}), surfaces = copy(config.surfaces or default_surfaces()), topology = copy(config.topology or {}), - logger = opts and opts.logger, + logger = opts.logger, + conn = opts.conn, + http_ref = opts.http_ref or config.http_ref, + exec = opts.exec or exec, + tmpdir = opts.tmpdir or config.tmpdir or os.getenv('TMPDIR') or '/tmp', }, Provider), nil end -function Provider:fetch_snapshot_op(_req) - -- Stub. The real driver should replace this method with HTTP-backed work. - return op.always({ +function Provider:_http_ref() + if self.http_ref then return self.http_ref end + if not self.conn then return nil end + return http_sdk.new_ref(self.conn, self.http.cap_id) +end + +function Provider:_uri(path) + local h = self.http + if type(h.host) ~= 'string' or h.host == '' then return nil, 'switch host not configured' end + local authority = h.host + if h.host:find(':', 1, true) and h.host:sub(1, 1) ~= '[' then authority = '[' .. h.host .. ']' end + local default_port = (h.scheme == 'https') and 443 or 80 + if tonumber(h.port) and tonumber(h.port) ~= default_port then + authority = authority .. ':' .. tostring(h.port) + end + return h.scheme .. '://' .. authority .. path_join(h.prefix, path), nil +end + +function Provider:_request_json_op(method, path, body, headers) + return op.guard(function () + local ref = self:_http_ref() + if not ref then return op.always(nil, 'http capability unavailable') end + local uri, uerr = self:_uri(path) + if not uri then return op.always(nil, uerr) end + local req_headers = copy(self.http.headers or {}) + for k, v in pairs(headers or {}) do req_headers[k] = v end + local args = { + uri = uri, + method = method or 'GET', + headers = req_headers, + } + if body ~= nil then args.body_source = blob_source.from_string(body) end + + return ref:open_exchange_op(args, { timeout = self.http.timeout_s }):wrap(function (reply, err) + if not reply then return nil, err end + local exchange = reply.exchange or reply + if not exchange or type(exchange.read_body_as_string_op) ~= 'function' then + return nil, 'http exchange handle unavailable' + end + local response_body, rerr = fibers.perform(exchange:read_body_as_string_op()) + local status = exchange.status and exchange:status() or nil + if exchange.shutdown_op then + local ok = fibers.perform(exchange:shutdown_op()) + if ok == nil and not rerr then rerr = 'http exchange shutdown failed' end + elseif exchange.terminate then + exchange:terminate('rtl8380m_http_done') + end + if rerr then return nil, rerr end + if not status_ok(status) then return nil, 'http status ' .. tostring(status) end + return decode_json(response_body) + end) + end) +end + +function Provider:_get_cgi_json_op(cmd, use_dummy) + local path = '/cgi/get.cgi?cmd=' .. url_escape_form(cmd) + if use_dummy then path = append_dummy(path) end + return self:_request_json_op('GET', path) +end + +function Provider:_post_cgi_json_op(path, payload, headers) + headers = headers or {} + headers['Content-Type'] = headers['Content-Type'] or 'application/x-www-form-urlencoded; charset=UTF-8' + headers['X-Requested-With'] = headers['X-Requested-With'] or 'XMLHttpRequest' + headers['Content-Length'] = tostring(#payload) + return self:_request_json_op('POST', path, payload, headers) +end + +function Provider:_login_op() + if self.auth.enabled == false then return op.always(true, nil) end + if not self.auth.username or not self.auth.password then return op.always(true, nil) end + + return fibers.run_scope_op(function () + local js, err = fibers.perform(self:_get_cgi_json_op('home_login', false)) + if not js then return false, 'failed to fetch modulus: ' .. tostring(err) end + local modulus = js.data and js.data.modulus + if type(modulus) ~= 'string' or modulus == '' then return false, 'login modulus missing' end + + local encoded, perr = fibers.perform(encrypt_password_op(self, modulus, self.auth.password)) + if not encoded then return false, 'encrypt error: ' .. tostring(perr) end + + local payload = ('_ds=1&username=%s&password=%s&_de=1'):format( + url_escape_form(self.auth.username), + encoded + ) + local _, post_err = fibers.perform(self:_post_cgi_json_op('/cgi/set.cgi?cmd=home_loginAuth', payload)) + if post_err then return false, post_err end + + for _ = 1, 10 do + local st_js, serr = fibers.perform(self:_get_cgi_json_op('home_loginStatus', false)) + if not st_js then return false, serr end + local status = st_js.data and st_js.data.status + if status == 'ok' then return true, nil end + if status == 'fail' then return false, 'login failed incorrect credentials' end + fibers.perform(sleep.sleep_op(1)) + end + return false, 'login timeout' + end):wrap(function (status, report, ok, err) + if status ~= 'ok' then return false, err or (report and report.primary) or status end + return ok, err + end) +end + +function Provider:_stats_op() + return fibers.run_scope_op(function () + local stats = { + system = { curr_time = 0, mem = 0, cpu = 0, power = 0, temp = 0 }, + ports = {}, + ports_poe = {}, + } + + local js, err = fibers.perform(self:_get_cgi_json_op('sys_sysTime', true)) + if not js then return nil, err end + stats.system.curr_time = js.data and js.data.sysCurrTime or nil + + js, err = fibers.perform(self:_get_cgi_json_op('sys_cpumem', true)) + if not js then return nil, err end + stats.system.cpu = js.data and js.data.cpu or nil + stats.system.mem = js.data and js.data.mem or nil + + js, err = fibers.perform(self:_get_cgi_json_op('panel_info', true)) + if not js then return nil, err end + stats.ports = js.data and js.data.ports or {} + + js, err = fibers.perform(self:_get_cgi_json_op('poe_poe', true)) + if not js then return nil, err end + stats.ports_poe = js.data and js.data.ports or {} + stats.system.power = js.data and js.data.devPower or nil + stats.system.temp = js.data and js.data.devTemp or nil + + return stats, nil + end):wrap(function (status, report, stats, err) + if status ~= 'ok' then return nil, err or (report and report.primary) or status end + return stats, err + end) +end + +function Provider:_snapshot_from_stats(stats) + local surfaces, observed_count = build_surfaces(self.surfaces, stats) + local telemetry = copy(self.telemetry or {}) + telemetry.system = copy(stats.system or {}) + telemetry.observed_ports = observed_count + + local topology = copy(self.topology or {}) + topology.provider = 'rtl8380m_http' + topology.port_count = observed_count + + return { ok = true, provider_id = self.id, mode = self.mode, @@ -83,13 +558,33 @@ function Provider:fetch_snapshot_op(_req) available = true, mode = self.mode, driver = 'rtl8380m_http', - stub = true, - base_url_configured = self.base_url ~= nil, + base_url_configured = self.http.host ~= nil, + host = self.http.host, }, - surfaces = copy(self.surfaces), - topology = copy(self.topology), - telemetry = copy(self.telemetry), - }) + surfaces = surfaces, + topology = topology, + telemetry = telemetry, + } +end + +function Provider:fetch_snapshot_op(_req) + return op.guard(function () + if type(self.http.host) ~= 'string' or self.http.host == '' then + return op.always(make_unavailable(self, 'host_not_configured', 'switch host not configured')) + end + return fibers.run_scope_op(function () + local logged_in, lerr = fibers.perform(self:_login_op()) + if logged_in ~= true then return make_unavailable(self, 'login_failed', lerr or 'login failed') end + local stats, serr = fibers.perform(self:_stats_op()) + if not stats then return make_unavailable(self, 'stats_failed', serr or 'stats failed') end + return self:_snapshot_from_stats(stats) + end):wrap(function (status, report, snapshot) + if status ~= 'ok' then + return make_unavailable(self, 'snapshot_failed', snapshot or (report and report.primary) or status) + end + return snapshot + end) + end) end function Provider:snapshot_op(req) return self:fetch_snapshot_op(req) end @@ -99,4 +594,10 @@ function Provider:set_poe_op(_req) return op.always(contract.read_only('set_poe' function Provider:bounce_op(_req) return op.always(contract.read_only('bounce')) end function Provider:terminate(_reason) return true end +M._test = { + build_surfaces = build_surfaces, + normalise_http_config = normalise_http_config, + resolve_env_value = resolve_env_value, +} + return M diff --git a/src/services/hal/managers/wired.lua b/src/services/hal/managers/wired.lua index 04ce81aa..b2dbe0ad 100644 --- a/src/services/hal/managers/wired.lua +++ b/src/services/hal/managers/wired.lua @@ -20,6 +20,7 @@ local state = { started = false, scope = nil, logger = nil, + conn = nil, dev_ev_ch = nil, cap_emit_ch = nil, drivers = {}, @@ -58,7 +59,12 @@ local function emit_state(class, id, key, payload) end local function emit_snapshot_now(provider_id, snapshot) - local ok, err = fibers.perform(emit_state('wired-provider', provider_id, 'status', snapshot.status or { state = 'available', available = snapshot.ok == true })) + local ok, err = fibers.perform(emit_state( + 'wired-provider', + provider_id, + 'status', + snapshot.status or { state = 'available', available = snapshot.ok == true } + )) if ok == false or ok == nil then return nil, err end ok, err = fibers.perform(emit_state('wired-provider', provider_id, 'surfaces', { surfaces = snapshot.surfaces or {} })) if ok == false or ok == nil then return nil, err end @@ -127,7 +133,9 @@ local function device_event_op(event_type, caps) end local function close_control_channels() - for _, ch in pairs(state.controls or {}) do if ch and type(ch.close) == 'function' then ch:close('reconfigured') end end + for _, ch in pairs(state.controls or {}) do + if ch and type(ch.close) == 'function' then ch:close('reconfigured') end + end state.controls = {} end @@ -155,9 +163,13 @@ local function normalise_provider_ids(config) for i = 1, #keys do local key = tostring(keys[i]) local rec = providers[key] - if type(rec) ~= 'table' then return nil, ('wired provider %s must be a table'):format(key) end + if type(rec) ~= 'table' then + return nil, ('wired provider %s must be a table'):format(key) + end local id = rec.id or key - if type(id) ~= 'string' or id == '' then return nil, ('wired provider %s id must be a non-empty string'):format(key) end + if type(id) ~= 'string' or id == '' then + return nil, ('wired provider %s id must be a non-empty string'):format(key) + end if seen[id] then return nil, ('duplicate wired provider id %s'):format(id) end seen[id] = true ids[#ids + 1] = id @@ -217,7 +229,7 @@ local function reconcile_device_caps(provider_ids) return true, nil end -function M.start_op(logger, dev_ev_ch, cap_emit_ch) +function M.start_op(logger, dev_ev_ch, cap_emit_ch, conn) return op.guard(function () if state.started then return op.always(true, nil) end local parent = fibers.current_scope() @@ -226,6 +238,7 @@ function M.start_op(logger, dev_ev_ch, cap_emit_ch) state.scope = child state.logger = logger + state.conn = conn state.dev_ev_ch = dev_ev_ch state.cap_emit_ch = cap_emit_ch state.controls = {} @@ -255,21 +268,35 @@ function M.apply_config_op(config) local id = provider_ids[i] local pcfg = configured_provider(config or {}, id) if not pcfg then - local eok, eerr = emit_snapshot_now(id, { status = { state = 'not_configured', available = false }, surfaces = {}, topology = {} }) + local eok, eerr = emit_snapshot_now(id, { + status = { state = 'not_configured', available = false }, + surfaces = {}, + topology = {}, + }) if eok ~= true then return false, eerr or 'wired provider status emit failed' end else local driver_config = {} for k, v in pairs(pcfg) do driver_config[k] = v end driver_config.id = driver_config.id or id - local driver, err = driver_mod.new(driver_config, { logger = state.logger, cap_emit_ch = state.cap_emit_ch }) - if not driver then return false, ('wired provider %s create failed: %s'):format(id, tostring(err)) end + local driver, err = driver_mod.new(driver_config, { + logger = state.logger, + cap_emit_ch = state.cap_emit_ch, + conn = state.conn, + }) + if not driver then + return false, ('wired provider %s create failed: %s'):format(id, tostring(err)) + end state.drivers[id] = driver local result = driver_result(id, 'snapshot', {}) if result.ok == true then local eok, eerr = emit_snapshot_now(id, result) if eok ~= true then return false, eerr or 'wired provider emit failed' end else - local eok, eerr = emit_snapshot_now(id, { status = { state = 'unavailable', available = false, err = result.err }, surfaces = {}, topology = {} }) + local eok, eerr = emit_snapshot_now(id, { + status = { state = 'unavailable', available = false, err = result.err }, + surfaces = {}, + topology = {}, + }) if eok ~= true then return false, eerr or 'wired provider status emit failed' end end end @@ -298,6 +325,7 @@ function M.terminate(reason) if state.scope then local scope = state.scope; state.scope = nil; scope:cancel(reason or 'terminated') end state.started = false state.logger = nil + state.conn = nil state.dev_ev_ch = nil state.cap_emit_ch = nil return true, nil diff --git a/tests/unit/hal/wired_provider_spec.lua b/tests/unit/hal/wired_provider_spec.lua index 186da8ae..bac0b5fe 100644 --- a/tests/unit/hal/wired_provider_spec.lua +++ b/tests/unit/hal/wired_provider_spec.lua @@ -1,9 +1,12 @@ local fibers = require 'fibers' +local op = require 'fibers.op' local provider = require 'services.hal.backends.wired.providers.rtl8380m_http' local tests = {} local function assert_true(v,msg) if v ~= true then error(msg or 'expected true',2) end end -local function assert_eq(a,b,msg) if a ~= b then error(msg or ('expected '..tostring(b)..', got '..tostring(a)),2) end end +local function assert_eq(a,b,msg) + if a ~= b then error(msg or ('expected '..tostring(b)..', got '..tostring(a)),2) end +end local function assert_not_nil(v,msg) if v == nil then error(msg or 'expected non-nil',2) end end @@ -25,17 +28,63 @@ function tests.test_wired_manager_provider_ids_are_config_driven() assert_eq(#ids, 0) end -function tests.test_rtl8380m_http_stub_is_read_only() +local function json_for_uri(uri) + if uri:find('cmd=sys_sysTime', 1, true) then + return '{"data":{"sysCurrTime":"2026-05-18 12:00:00"}}' + elseif uri:find('cmd=sys_cpumem', 1, true) then + return '{"data":{"cpu":12,"mem":34}}' + elseif uri:find('cmd=panel_info', 1, true) then + return '{"data":{"ports":[{"port":1,"link":"up","speed":"1000M","duplex":"full"}]}}' + elseif uri:find('cmd=poe_poe', 1, true) then + return '{"data":{"devPower":0,"devTemp":42,"ports":[{"port":1,"status":"off","power":0}]}}' + end + return '{"data":{}}' +end + +local function fake_http_ref(calls) + return { + open_exchange_op = function (_, args) + calls[#calls + 1] = args + local exchange = { + status = function () return '200' end, + read_body_as_string_op = function () return op.always(json_for_uri(args.uri), nil) end, + shutdown_op = function () return op.always(true, nil) end, + } + return op.always({ exchange = exchange }, nil) + end, + } +end + +function tests.test_rtl8380m_http_provider_uses_http_ref_and_is_read_only() fibers.run(function () - local p = assert(provider.new({ id = 'switch-main' })) + local calls = {} + local p = assert(provider.new({ + id = 'switch-main', + http = { host = '192.0.2.10', cap_id = 'main' }, + }, { + http_ref = fake_http_ref(calls), + })) local snap = fibers.perform(p:snapshot_op({})) assert_true(snap.ok) assert_eq(snap.provider_id, 'switch-main') - assert_true(snap.status.stub) + assert_eq(snap.status.driver, 'rtl8380m_http') + assert_eq(snap.surfaces['port-1'].link.state, 'up') + assert_eq(#calls, 4) + assert_eq(calls[1].method, 'GET') local res = fibers.perform(p:apply_attachments_op({})) assert_eq(res.code, 'read_only') assert_not_nil(res.err) end) end +function tests.test_rtl8380m_http_provider_reports_unavailable_without_host() + fibers.run(function () + local p = assert(provider.new({ id = 'switch-main' })) + local snap = fibers.perform(p:snapshot_op({})) + assert_eq(snap.ok, false) + assert_eq(snap.status.state, 'unavailable') + assert_eq(snap.status.code, 'host_not_configured') + end) +end + return tests From 59dc6115a1a24ef8164f1cfd582e729ed26f9ceb Mon Sep 17 00:00:00 2001 From: cpunt Date: Mon, 18 May 2026 17:28:13 +0000 Subject: [PATCH 2/2] lua: narrow rtl8380m wired provider http access --- src/services/hal.lua | 24 +- .../wired/providers/rtl8380m_http.lua | 14 +- src/services/hal/managers/wired.lua | 51 +--- src/services/http/policy.lua | 37 ++- src/services/http/transport/client.lua | 4 + .../http/transport/tolerant_http1.lua | 243 ++++++++++++++++++ tests/run.lua | 1 + tests/unit/hal/wired_provider_spec.lua | 22 ++ tests/unit/http/test_policy.lua | 35 ++- .../http/transport/test_tolerant_http1.lua | 107 ++++++++ 10 files changed, 486 insertions(+), 52 deletions(-) create mode 100644 src/services/http/transport/tolerant_http1.lua create mode 100644 tests/unit/http/transport/test_tolerant_http1.lua diff --git a/src/services/hal.lua b/src/services/hal.lua index 86193d04..06182de7 100644 --- a/src/services/hal.lua +++ b/src/services/hal.lua @@ -10,6 +10,7 @@ local channel = require "fibers.channel" local sleep = require "fibers.sleep" local tablex = require 'shared.table' +local http_sdk = require 'services.http.sdk' local perform = fibers.perform @@ -466,6 +467,15 @@ local function availability_flag(event_type) return event_type == 'added' end +local function narrow_http_ref(ref) + if ref == nil then return nil end + return { + status_op = function (_, opts) return ref:status_op(opts) end, + open_exchange_op = function (_, args, opts) return ref:open_exchange_op(args, opts) end, + exchange_op = function (_, args, opts) return ref:exchange_op(args, opts) end, + } +end + local function availability_payload(event_type, extra) local out = { state = availability_state(event_type), @@ -722,7 +732,7 @@ function HalService.start(conn, opts) return out end - function registry:terminate_caps(_reason) + function registry:terminate_caps(reason) -- reason is accepted for finaliser-shaped call sites; bus endpoints -- expose immediate unbind without a reason parameter. for _, class_caps in pairs(self.caps) do @@ -1082,6 +1092,14 @@ function HalService.start(conn, opts) component = 'manager', manager = name, }) + local deps + if name == 'wired' then + deps = { + http_ref_for = function (cap_id) + return narrow_http_ref(http_sdk.new_ref(conn, cap_id or 'main')) + end, + } + end return perform(manager_call_with_timeout_op( name, @@ -1091,7 +1109,7 @@ function HalService.start(conn, opts) manager_logger, dev_ev_ch, cap_emit_ch, - conn + deps )) end @@ -1256,7 +1274,7 @@ function HalService.start(conn, opts) svc:obs_log('info', { what = 'subscribed', topic = 'cfg/' .. svc.name }) while true do - local source, a = perform(op.named_choice({ + local source, a, b = perform(op.named_choice({ rpc = op.choice(registry:rpc_ops()), manager_fault = op.choice(manager_fault_ops()), cap_emit = cap_emit_ch:get_op(), diff --git a/src/services/hal/backends/wired/providers/rtl8380m_http.lua b/src/services/hal/backends/wired/providers/rtl8380m_http.lua index e9b4b6a5..2755a1c5 100644 --- a/src/services/hal/backends/wired/providers/rtl8380m_http.lua +++ b/src/services/hal/backends/wired/providers/rtl8380m_http.lua @@ -13,7 +13,6 @@ local exec = require 'fibers.io.exec' local blob_source = require 'devicecode.blob_source' local resource = require 'devicecode.support.resource' -local http_sdk = require 'services.http.sdk' local contract = require 'services.hal.backends.wired.contract' local tablex = require 'shared.table' @@ -47,7 +46,7 @@ local function url_escape_form(s) end local function urlencode_b64(s) - return (tostring(s or ''):gsub('[+/=]', function(c) + return (tostring(s or ''):gsub('[+/=]', function (c) return ('%%%02X'):format(string.byte(c)) end)) end @@ -394,8 +393,8 @@ function M.new(config, opts) surfaces = copy(config.surfaces or default_surfaces()), topology = copy(config.topology or {}), logger = opts.logger, - conn = opts.conn, http_ref = opts.http_ref or config.http_ref, + http_ref_for = opts.http_ref_for or config.http_ref_for, exec = opts.exec or exec, tmpdir = opts.tmpdir or config.tmpdir or os.getenv('TMPDIR') or '/tmp', }, Provider), nil @@ -403,8 +402,8 @@ end function Provider:_http_ref() if self.http_ref then return self.http_ref end - if not self.conn then return nil end - return http_sdk.new_ref(self.conn, self.http.cap_id) + if self.http_ref_for then return self.http_ref_for(self.http.cap_id) end + return nil end function Provider:_uri(path) @@ -431,10 +430,12 @@ function Provider:_request_json_op(method, path, body, headers) uri = uri, method = method or 'GET', headers = req_headers, + response_parser = 'tolerant-http1', + timeout_s = self.http.timeout_s, } if body ~= nil then args.body_source = blob_source.from_string(body) end - return ref:open_exchange_op(args, { timeout = self.http.timeout_s }):wrap(function (reply, err) + return ref:open_exchange_op(args):wrap(function (reply, err) if not reply then return nil, err end local exchange = reply.exchange or reply if not exchange or type(exchange.read_body_as_string_op) ~= 'function' then @@ -465,7 +466,6 @@ function Provider:_post_cgi_json_op(path, payload, headers) headers = headers or {} headers['Content-Type'] = headers['Content-Type'] or 'application/x-www-form-urlencoded; charset=UTF-8' headers['X-Requested-With'] = headers['X-Requested-With'] or 'XMLHttpRequest' - headers['Content-Length'] = tostring(#payload) return self:_request_json_op('POST', path, payload, headers) end diff --git a/src/services/hal/managers/wired.lua b/src/services/hal/managers/wired.lua index b2dbe0ad..2806fbab 100644 --- a/src/services/hal/managers/wired.lua +++ b/src/services/hal/managers/wired.lua @@ -20,7 +20,7 @@ local state = { started = false, scope = nil, logger = nil, - conn = nil, + http_ref_for = nil, dev_ev_ch = nil, cap_emit_ch = nil, drivers = {}, @@ -59,12 +59,7 @@ local function emit_state(class, id, key, payload) end local function emit_snapshot_now(provider_id, snapshot) - local ok, err = fibers.perform(emit_state( - 'wired-provider', - provider_id, - 'status', - snapshot.status or { state = 'available', available = snapshot.ok == true } - )) + local ok, err = fibers.perform(emit_state('wired-provider', provider_id, 'status', snapshot.status or { state = 'available', available = snapshot.ok == true })) if ok == false or ok == nil then return nil, err end ok, err = fibers.perform(emit_state('wired-provider', provider_id, 'surfaces', { surfaces = snapshot.surfaces or {} })) if ok == false or ok == nil then return nil, err end @@ -133,9 +128,7 @@ local function device_event_op(event_type, caps) end local function close_control_channels() - for _, ch in pairs(state.controls or {}) do - if ch and type(ch.close) == 'function' then ch:close('reconfigured') end - end + for _, ch in pairs(state.controls or {}) do if ch and type(ch.close) == 'function' then ch:close('reconfigured') end end state.controls = {} end @@ -163,13 +156,9 @@ local function normalise_provider_ids(config) for i = 1, #keys do local key = tostring(keys[i]) local rec = providers[key] - if type(rec) ~= 'table' then - return nil, ('wired provider %s must be a table'):format(key) - end + if type(rec) ~= 'table' then return nil, ('wired provider %s must be a table'):format(key) end local id = rec.id or key - if type(id) ~= 'string' or id == '' then - return nil, ('wired provider %s id must be a non-empty string'):format(key) - end + if type(id) ~= 'string' or id == '' then return nil, ('wired provider %s id must be a non-empty string'):format(key) end if seen[id] then return nil, ('duplicate wired provider id %s'):format(id) end seen[id] = true ids[#ids + 1] = id @@ -229,7 +218,7 @@ local function reconcile_device_caps(provider_ids) return true, nil end -function M.start_op(logger, dev_ev_ch, cap_emit_ch, conn) +function M.start_op(logger, dev_ev_ch, cap_emit_ch, deps) return op.guard(function () if state.started then return op.always(true, nil) end local parent = fibers.current_scope() @@ -238,7 +227,7 @@ function M.start_op(logger, dev_ev_ch, cap_emit_ch, conn) state.scope = child state.logger = logger - state.conn = conn + state.http_ref_for = type(deps) == 'table' and deps.http_ref_for or nil state.dev_ev_ch = dev_ev_ch state.cap_emit_ch = cap_emit_ch state.controls = {} @@ -268,35 +257,23 @@ function M.apply_config_op(config) local id = provider_ids[i] local pcfg = configured_provider(config or {}, id) if not pcfg then - local eok, eerr = emit_snapshot_now(id, { - status = { state = 'not_configured', available = false }, - surfaces = {}, - topology = {}, - }) + local eok, eerr = emit_snapshot_now(id, { status = { state = 'not_configured', available = false }, surfaces = {}, topology = {} }) if eok ~= true then return false, eerr or 'wired provider status emit failed' end else local driver_config = {} for k, v in pairs(pcfg) do driver_config[k] = v end driver_config.id = driver_config.id or id - local driver, err = driver_mod.new(driver_config, { - logger = state.logger, - cap_emit_ch = state.cap_emit_ch, - conn = state.conn, - }) - if not driver then - return false, ('wired provider %s create failed: %s'):format(id, tostring(err)) - end + local driver_opts = { logger = state.logger, cap_emit_ch = state.cap_emit_ch } + if driver_config.provider == 'rtl8380m_http' then driver_opts.http_ref_for = state.http_ref_for end + local driver, err = driver_mod.new(driver_config, driver_opts) + if not driver then return false, ('wired provider %s create failed: %s'):format(id, tostring(err)) end state.drivers[id] = driver local result = driver_result(id, 'snapshot', {}) if result.ok == true then local eok, eerr = emit_snapshot_now(id, result) if eok ~= true then return false, eerr or 'wired provider emit failed' end else - local eok, eerr = emit_snapshot_now(id, { - status = { state = 'unavailable', available = false, err = result.err }, - surfaces = {}, - topology = {}, - }) + local eok, eerr = emit_snapshot_now(id, { status = { state = 'unavailable', available = false, err = result.err }, surfaces = {}, topology = {} }) if eok ~= true then return false, eerr or 'wired provider status emit failed' end end end @@ -325,7 +302,7 @@ function M.terminate(reason) if state.scope then local scope = state.scope; state.scope = nil; scope:cancel(reason or 'terminated') end state.started = false state.logger = nil - state.conn = nil + state.http_ref_for = nil state.dev_ev_ch = nil state.cap_emit_ch = nil return true, nil diff --git a/src/services/http/policy.lua b/src/services/http/policy.lua index 9ae42e42..d6ac14c4 100644 --- a/src/services/http/policy.lua +++ b/src/services/http/policy.lua @@ -93,7 +93,13 @@ end function M.validate_listen_args(args) args = args or {} if type(args) ~= 'table' then return nil, 'invalid_args' end - local ok, ferr = require_only_fields(args, { host = true, port = true, path = true, tls = true, max_accept_queue = true }) + local ok, ferr = require_only_fields(args, { + host = true, + port = true, + path = true, + tls = true, + max_accept_queue = true, + }) if not ok then return nil, ferr end local out = {} if args.host ~= nil and type(args.host) ~= 'string' then return nil, 'invalid_args' end @@ -102,7 +108,11 @@ function M.validate_listen_args(args) end if args.path ~= nil and type(args.path) ~= 'string' then return nil, 'invalid_args' end if args.tls ~= nil and type(args.tls) ~= 'boolean' then return nil, 'invalid_args' end - if args.max_accept_queue ~= nil and (type(args.max_accept_queue) ~= 'number' or args.max_accept_queue < 0) then return nil, 'invalid_args' end + if args.max_accept_queue ~= nil + and (type(args.max_accept_queue) ~= 'number' or args.max_accept_queue < 0) + then + return nil, 'invalid_args' + end out.host = args.host out.port = args.port out.path = args.path @@ -123,10 +133,24 @@ local function host_denied(parsed_uri, opts) return false end +local function validate_response_parser(v) + if v == nil then return nil, nil end + if v == 'strict' or v == 'tolerant-http1' then return v, nil end + return nil, 'invalid_args' +end + function M.validate_exchange_args(args, opts) opts = opts or {} if type(args) ~= 'table' then return nil, 'invalid_args' end - local ok, ferr = require_only_fields(args, { uri = true, method = true, headers = true, body_source = true, response_sink = true }) + local ok, ferr = require_only_fields(args, { + uri = true, + method = true, + headers = true, + body_source = true, + response_sink = true, + response_parser = true, + timeout_s = true, + }) if not ok then return nil, ferr end local uri, uerr = M.validate_uri(args.uri, opts) if not uri then return nil, uerr end @@ -136,6 +160,11 @@ function M.validate_exchange_args(args, opts) if not method then return nil, merr end local headers, herr = copy_headers(args.headers) if herr then return nil, herr end + local response_parser, rperr = validate_response_parser(args.response_parser) + if rperr then return nil, rperr end + if response_parser == 'tolerant-http1' and uri.scheme ~= 'http' then return nil, 'unsupported_scheme' end + local timeout_s = args.timeout_s + if timeout_s ~= nil and (type(timeout_s) ~= 'number' or timeout_s < 0) then return nil, 'invalid_args' end local bodies, derr = body.validate_exchange_bodies(args) if not bodies then return nil, derr end @@ -147,6 +176,8 @@ function M.validate_exchange_args(args, opts) _uri = uri, body_source = bodies.source, response_sink = bodies.sink, + response_parser = response_parser, + timeout_s = timeout_s, }, nil end diff --git a/src/services/http/transport/client.lua b/src/services/http/transport/client.lua index c533b97f..34da4a38 100644 --- a/src/services/http/transport/client.lua +++ b/src/services/http/transport/client.lua @@ -5,6 +5,7 @@ local op = require 'fibers.op' local headers_mod = require 'services.http.headers' local terminate = require 'services.http.transport.terminate' +local tolerant_http1 = require 'services.http.transport.tolerant_http1' local M = {} @@ -37,6 +38,9 @@ end function M.open_exchange_op(driver, checked_args, opts) opts = opts or {} + if checked_args and checked_args.response_parser == 'tolerant-http1' then + return tolerant_http1.open_exchange_op(driver, checked_args, opts) + end return op.guard(function () local request_module, rerr = require_request(opts) if not request_module then return op.always(nil, rerr) end diff --git a/src/services/http/transport/tolerant_http1.lua b/src/services/http/transport/tolerant_http1.lua new file mode 100644 index 00000000..20a71d61 --- /dev/null +++ b/src/services/http/transport/tolerant_http1.lua @@ -0,0 +1,243 @@ +-- services/http/transport/tolerant_http1.lua +-- +-- Lenient HTTP/1.0 client transport for devices whose embedded HTTP servers do +-- not produce responses accepted by lua-http. This module remains below the +-- services.http boundary; callers still use cap/http through services.http.sdk. + +local op = require 'fibers.op' + +local M = {} + +local function require_socket(opts) + if opts and opts.socket_module then return opts.socket_module end + local ok, mod = pcall(require, 'cqueues.socket') + if not ok then return nil, mod end + return mod +end + +local function lower(s) + return tostring(s or ''):lower() +end + +local Headers = {} +Headers.__index = Headers + +local function new_headers(status, pairs) + local self = setmetatable({ + _pairs = { { ':status', tostring(status or 200) } }, + _map = { [':status'] = tostring(status or 200) }, + }, Headers) + for i = 1, #(pairs or {}) do + local k = tostring(pairs[i][1] or '') + local v = tostring(pairs[i][2] or '') + self._pairs[#self._pairs + 1] = { k, v } + self._map[lower(k)] = self._map[lower(k)] or v + end + return self +end + +function Headers:get(name) + return self._map[lower(name)] +end + +function Headers:each() + local i = 0 + return function () + i = i + 1 + local row = self._pairs[i] + if row then return row[1], row[2] end + end +end + +local Stream = {} +Stream.__index = Stream + +local function new_stream(body) + return setmetatable({ _body = body or '', _off = 1, _closed = false }, Stream) +end + +function Stream:get_next_chunk() + if self._closed then return nil end + if self._off > #self._body then return nil end + local chunk = self._body:sub(self._off) + self._off = #self._body + 1 + return chunk +end + +function Stream:get_body_chars(n) + if self._closed then return nil, 'closed' end + n = tonumber(n) or 0 + if n <= 0 then return '' end + local chunk = self._body:sub(self._off, self._off + n - 1) + self._off = self._off + #chunk + return chunk +end + +function Stream:get_body_as_string() + if self._closed then return nil, 'closed' end + local body = self._body:sub(self._off) + self._off = #self._body + 1 + return body +end + +function Stream:shutdown() + self._closed = true + return true +end + +function Stream:close() + return self:shutdown() +end + +local function body_from_iterator(iter) + if iter == nil then return nil, nil end + local out = {} + while true do + local chunk = iter() + if chunk == nil then break end + out[#out + 1] = tostring(chunk) + end + return table.concat(out), nil +end + +local function header_pairs(headers) + local out = {} + for k, v in pairs(headers or {}) do + if type(v) == 'table' then + for i = 1, #v do out[#out + 1] = { k, v[i] } end + else + out[#out + 1] = { k, v } + end + end + return out +end + +local function has_header(pairs, name) + local want = lower(name) + for i = 1, #pairs do + if lower(pairs[i][1]) == want then return true end + end + return false +end + +local function build_request(args) + local uri = args._uri or {} + local path = uri.path or '/' + local authority = uri.authority or uri.host or '' + local method = args.method or 'GET' + local body, berr = body_from_iterator(args._request_body) + if berr then return nil, berr end + + local headers = {} + for _, pair in ipairs(header_pairs(args.headers)) do + headers[#headers + 1] = pair + end + if body ~= nil and not has_header(headers, 'Content-Length') then + headers[#headers + 1] = { 'Content-Length', tostring(#body) } + end + + local lines = { + ('%s %s HTTP/1.0\r\n'):format(method, path), + ('Host: %s\r\n'):format(authority), + 'Accept: */*\r\n', + 'Connection: close\r\n', + } + for i = 1, #headers do + lines[#lines + 1] = ('%s: %s\r\n'):format(tostring(headers[i][1]), tostring(headers[i][2] or '')) + end + lines[#lines + 1] = '\r\n' + if body ~= nil then lines[#lines + 1] = body end + return table.concat(lines), nil +end + +local function parse_response(raw) + raw = tostring(raw or '') + local head, body = raw:match('^(.-)\r\n\r\n(.*)$') + if not head then head, body = raw:match('^(.-)\n\n(.*)$') end + if not head then + local i = raw:find('{', 1, true) + if i then return new_headers(200), new_stream(raw:sub(i)) end + return nil, 'invalid_http_response' + end + + local lines = {} + for line in head:gmatch('[^\r\n]+') do lines[#lines + 1] = line end + local status = lines[1] and lines[1]:match('^HTTP/%S+%s+(%d+)') or nil + status = tonumber(status) or 200 + local pairs = {} + for i = 2, #lines do + local k, v = lines[i]:match('^([^:]+):%s*(.*)$') + if k and k ~= '' then pairs[#pairs + 1] = { k, v or '' } end + end + return new_headers(status, pairs), new_stream(body or '') +end + +local function read_all(sock) + local chunks = {} + while true do + local ok, buf, err, part = pcall(function () return sock:read(4096) end) + if not ok then return nil, buf end + if buf and #buf > 0 then + chunks[#chunks + 1] = buf + elseif part and #part > 0 then + chunks[#chunks + 1] = part + end + if not buf then + if err and err ~= 'eof' then return nil, 'read error: ' .. tostring(err) end + break + end + end + return table.concat(chunks), nil +end + +local function write_request(sock, request) + local ok, wres, werr = pcall(function () return sock:write(request) end) + if not ok then return nil, wres end + if not wres then return nil, werr end + local fok, ferr = pcall(function () return sock:flush() end) + if not fok then return nil, ferr end + return true, nil +end + +function M.open_exchange_op(driver, args, opts) + opts = opts or {} + return op.guard(function () + local socket, serr = require_socket(opts) + if not socket then return op.always(nil, serr) end + local active = { socket = nil } + return driver:run_op('http.tolerant_http1.open_exchange', function () + local uri = args._uri or {} + if uri.scheme ~= nil and uri.scheme ~= 'http' then return nil, 'unsupported_scheme' end + local port = uri.port or ((uri.scheme == 'https') and 443 or 80) + local sock, err = socket.connect(uri.host, port) + if not sock then return nil, err or 'connect_failed' end + active.socket = sock + if sock.settimeout then sock:settimeout(args.timeout_s or opts.backend_timeout or opts.timeout) end + if sock.setmode then sock:setmode('b', 'b') end + + local request, rerr = build_request(args) + if not request then return nil, rerr end + local ok, werr = write_request(sock, request) + if not ok then return nil, 'write/flush failed: ' .. tostring(werr) end + local raw, read_err = read_all(sock) + if sock.close then sock:close() end + active.socket = nil + if not raw then return nil, read_err end + return parse_response(raw) + end, { + on_active_abort = function (reason) + local sock = active.socket + active.socket = nil + if sock and sock.close then pcall(function () sock:close(reason or 'aborted') end) end + end, + }) + end) +end + +M._test = { + build_request = build_request, + parse_response = parse_response, + new_stream = new_stream, +} + +return M diff --git a/tests/run.lua b/tests/run.lua index 0a89c468..9bfcdbfb 100644 --- a/tests/run.lua +++ b/tests/run.lua @@ -105,6 +105,7 @@ local files = { 'unit.support.test_config_watch_architecture', 'unit.support.test_service_events', 'unit.http.transport.test_cqueues_driver', + 'unit.http.transport.test_tolerant_http1', 'unit.http.transport.test_lua_http', 'unit.http.transport.test_websocket', 'unit.http.transport.test_terminate', diff --git a/tests/unit/hal/wired_provider_spec.lua b/tests/unit/hal/wired_provider_spec.lua index bac0b5fe..eca8a01c 100644 --- a/tests/unit/hal/wired_provider_spec.lua +++ b/tests/unit/hal/wired_provider_spec.lua @@ -71,12 +71,34 @@ function tests.test_rtl8380m_http_provider_uses_http_ref_and_is_read_only() assert_eq(snap.surfaces['port-1'].link.state, 'up') assert_eq(#calls, 4) assert_eq(calls[1].method, 'GET') + assert_eq(calls[1].response_parser, 'tolerant-http1') + assert_eq(calls[1].timeout_s, 10) local res = fibers.perform(p:apply_attachments_op({})) assert_eq(res.code, 'read_only') assert_not_nil(res.err) end) end +function tests.test_rtl8380m_http_provider_accepts_narrow_http_ref_factory() + fibers.run(function () + local calls = {} + local requested_cap + local p = assert(provider.new({ + id = 'switch-main', + http = { host = '192.0.2.10', cap_id = 'switch-http' }, + }, { + http_ref_for = function (cap_id) + requested_cap = cap_id + return fake_http_ref(calls) + end, + })) + local snap = fibers.perform(p:snapshot_op({})) + assert_true(snap.ok) + assert_eq(requested_cap, 'switch-http') + assert_eq(#calls, 4) + end) +end + function tests.test_rtl8380m_http_provider_reports_unavailable_without_host() fibers.run(function () local p = assert(provider.new({ id = 'switch-main' })) diff --git a/tests/unit/http/test_policy.lua b/tests/unit/http/test_policy.lua index b5ea9f9b..82a58cd5 100644 --- a/tests/unit/http/test_policy.lua +++ b/tests/unit/http/test_policy.lua @@ -54,11 +54,20 @@ end function M.test_validate_exchange_accepts_body_object_capabilities_and_rejects_remote_reference_tables() local source = { read_chunk_op = function () end, terminate = function () return true end } local sink = { write_chunk_op = function () end, terminate = function () return true end } - local checked = ok(policy.validate_exchange_args { uri = 'http://example.test/', method = 'POST', body_source = source, response_sink = sink }) + local checked = ok(policy.validate_exchange_args { + uri = 'http://example.test/', + method = 'POST', + body_source = source, + response_sink = sink, + }) eq(checked.body_source, source) eq(checked.response_sink, sink) - local bad, err = policy.validate_exchange_args { uri = 'http://example.test/', method = 'POST', body_source = { kind = 'remote-ref' } } + local bad, err = policy.validate_exchange_args { + uri = 'http://example.test/', + method = 'POST', + body_source = { kind = 'remote-ref' }, + } eq(bad, nil) eq(err, 'invalid_args') bad, err = policy.validate_exchange_args { uri = 'http://example.test/', method = 'POST', source = source } @@ -66,6 +75,28 @@ function M.test_validate_exchange_accepts_body_object_capabilities_and_rejects_r eq(err, 'invalid_args') end +function M.test_validate_exchange_accepts_tolerant_parser_and_timeout() + local checked = ok(policy.validate_exchange_args { + uri = 'http://example.test/', + method = 'GET', + response_parser = 'tolerant-http1', + timeout_s = 10, + }) + eq(checked.response_parser, 'tolerant-http1') + eq(checked.timeout_s, 10) + + local bad, err = policy.validate_exchange_args { uri = 'http://example.test/', response_parser = 'raw-socket' } + eq(bad, nil) + eq(err, 'invalid_args') + bad, err = policy.validate_exchange_args { uri = 'http://example.test/', timeout_s = -1 } + eq(bad, nil) + eq(err, 'invalid_args') + + bad, err = policy.validate_exchange_args { uri = 'https://example.test/', response_parser = 'tolerant-http1' } + eq(bad, nil) + eq(err, 'unsupported_scheme') +end + function M.test_validate_listen_defaults_loopback_ephemeral_port() local args = ok(policy.validate_listen_args {}) eq(args.host, '127.0.0.1') diff --git a/tests/unit/http/transport/test_tolerant_http1.lua b/tests/unit/http/transport/test_tolerant_http1.lua new file mode 100644 index 00000000..190e3c19 --- /dev/null +++ b/tests/unit/http/transport/test_tolerant_http1.lua @@ -0,0 +1,107 @@ +local fibers = require 'fibers' +local tolerant = require 'services.http.transport.tolerant_http1' + +local M = {} + +local function eq(a, b, msg) + if a ~= b then error((msg or 'assertion failed') .. ': expected ' .. tostring(b) .. ', got ' .. tostring(a), 2) end +end + +local function ok(v, msg) + if not v then error(msg or 'assertion failed', 2) end + return v +end + +function M.test_parse_response_accepts_simple_http1_body() + local headers, stream = tolerant._test.parse_response( + 'HTTP/1.1 200 OK\r\nServer: Hydra/0.1.8\r\nConnection: close\r\n\r\n{"data":{"ok":true}}' + ) + ok(headers) + eq(headers:get(':status'), '200') + eq(headers:get('server'), 'Hydra/0.1.8') + eq(stream:get_body_as_string(), '{"data":{"ok":true}}') +end + +function M.test_parse_response_accepts_body_when_status_line_is_missing() + local headers, stream = tolerant._test.parse_response('not http before json {"data":{"ok":true}}') + ok(headers) + eq(headers:get(':status'), '200') + eq(stream:get_body_as_string(), '{"data":{"ok":true}}') +end + +function M.test_build_request_adds_http10_host_and_single_content_length() + local chunks = { 'abc' } + local request = ok(tolerant._test.build_request({ + method = 'POST', + _uri = { path = '/cgi/set.cgi?cmd=home_loginAuth', authority = '172.28.100.9' }, + headers = { ['Content-Type'] = 'application/x-www-form-urlencoded', ['Content-Length'] = '3' }, + _request_body = function () return table.remove(chunks, 1) end, + })) + ok(request:find('POST /cgi/set.cgi?cmd=home_loginAuth HTTP/1.0\r\n', 1, true)) + ok(request:find('Host: 172.28.100.9\r\n', 1, true)) + ok(request:find('\r\n\r\nabc', 1, true)) + local _, count = request:gsub('Content%-Length:', '') + eq(count, 1) +end + +function M.test_open_exchange_uses_socket_module_inside_http_transport() + fibers.run(function () + local written + local socket_mod = { + connect = function (host, port) + eq(host, '172.28.100.9') + eq(port, 80) + local reads = { + 'HTTP/1.0 202 Accepted\r\n\r\n{"data":{"ok":true}}', + nil, + } + return { + settimeout = function (_, timeout) eq(timeout, 10) end, + setmode = function (_, rmode, wmode) eq(rmode, 'b'); eq(wmode, 'b') end, + write = function (_, data) written = data; return true end, + flush = function () return true end, + read = function () + local next_chunk = table.remove(reads, 1) + if next_chunk == nil then return nil, 'eof' end + return next_chunk + end, + close = function () return true end, + } + end, + } + local driver = { + run_op = function (_, _, fn) return fibers.always(fn()) end, + } + local headers, stream = fibers.perform(tolerant.open_exchange_op(driver, { + method = 'GET', + response_parser = 'tolerant-http1', + timeout_s = 10, + _uri = { + scheme = 'http', + host = '172.28.100.9', + port = 80, + path = '/cgi/get.cgi?cmd=panel_info', + authority = '172.28.100.9', + }, + }, { socket_module = socket_mod })) + ok(headers) + eq(headers:get(':status'), '202') + eq(stream:get_body_as_string(), '{"data":{"ok":true}}') + ok(written:find('GET /cgi/get.cgi?cmd=panel_info HTTP/1.0', 1, true)) + end) +end + +function M.test_open_exchange_rejects_https_for_tolerant_socket_transport() + fibers.run(function () + local driver = { + run_op = function (_, _, fn) return fibers.always(fn()) end, + } + local headers, err = fibers.perform(tolerant.open_exchange_op(driver, { + _uri = { scheme = 'https', host = 'example.test', port = 443, path = '/', authority = 'example.test' }, + }, { socket_module = { connect = function () error('connect should not run', 0) end } })) + eq(headers, nil) + eq(err, 'unsupported_scheme') + end) +end + +return M