// Source: pluginengine01 / crates/bex-core/src/host_state.rs
// Uploaded by krystv ("Upload 107 files", commit 3374e90).
use crate::engine::bex::plugin::common::Attr;
use crate::engine::bex::plugin::http::{Method, Request, Response};
use crate::engine::bex::plugin::js::JsOpts;
use crate::engine::bex::plugin::log::Level;
use crate::http_service::HttpHostService;
use bex_db::BexDb;
use bex_js::JsPool;
use bex_types::Manifest;
use std::sync::Arc;
use wasmtime_wasi::IoView;
/// Per-`HostState` resource limiter that bounds how large a WASM plugin's
/// linear memory may grow.
///
/// Growth requests whose desired size exceeds `max_memory_bytes` are refused
/// by this type's `wasmtime::ResourceLimiter` implementation.
pub struct BexResourceLimiter {
    /// Ceiling, in bytes, for the desired size of any memory growth request.
    pub max_memory_bytes: usize,
}
impl wasmtime::ResourceLimiter for BexResourceLimiter {
    /// Approves a linear-memory growth only while the desired size stays within
    /// `max_memory_bytes`; a refusal is logged at `warn` level.
    fn memory_growing(
        &mut self,
        _current: usize,
        desired: usize,
        _maximum: Option<usize>,
    ) -> anyhow::Result<bool> {
        let within_budget = desired <= self.max_memory_bytes;
        if !within_budget {
            tracing::warn!(
                desired,
                max = self.max_memory_bytes,
                "Memory limit exceeded"
            );
        }
        Ok(within_budget)
    }

    /// Approves table growth up to a fixed ceiling of one million elements.
    fn table_growing(
        &mut self,
        _current: usize,
        desired: usize,
        _maximum: Option<usize>,
    ) -> anyhow::Result<bool> {
        // Fixed cap, independent of the configured memory limit.
        const MAX_TABLE_ELEMENTS: usize = 1_000_000;
        Ok(desired <= MAX_TABLE_ELEMENTS)
    }
}
/// State passed to every WASM store instance. Implements all host traits
/// required by the WIT imports (http, kv, secrets, log, clock, rng, js)
/// and WASI interfaces for component model compatibility.
pub struct HostState {
    /// Shared HTTP service used by `plugin::http::Host::send_request`.
    pub http_client: Arc<HttpHostService>,
    /// Shared database backing the per-plugin KV store and secrets lookups.
    pub db: Arc<BexDb>,
    /// Shared JS engine pool used by the `plugin::js` host functions.
    pub js_pool: Arc<JsPool>,
    /// Tokio runtime used to `block_on` async HTTP calls from the synchronous
    /// host functions.
    pub runtime: Arc<tokio::runtime::Runtime>,
    /// Identifier of the plugin owning this store; scopes KV, secrets and JS
    /// calls and tags every log event emitted on the plugin's behalf.
    pub plugin_id: String,
    /// The plugin's manifest; consulted for the network allowlist
    /// (`allows_host`), `storage`, `secrets` and `allow_js`.
    pub manifest: Arc<Manifest>,
    /// User-Agent string. NOTE(review): not read by any code in this file —
    /// presumably consumed elsewhere; confirm.
    pub user_agent: Arc<str>,
    /// Default HTTP timeout (ms) used when a request does not specify one.
    pub http_timeout_ms: u32,
    /// HTTP response bodies longer than this are truncated before they are
    /// returned to the guest.
    pub max_response_bytes: u64,
    /// Instant captured at construction; `clock::monotonic` reports the
    /// milliseconds elapsed since it.
    pub start_mono: std::time::Instant,
    /// WASI context handed out through `WasiView::ctx`.
    pub wasi: wasmtime_wasi::WasiCtx,
    /// Component resource table handed out through `IoView::table`.
    pub table: wasmtime::component::ResourceTable,
    /// Memory/table growth limiter. NOTE(review): only takes effect if it is
    /// registered on the `Store` (e.g. via `Store::limiter`) — not shown here.
    pub limiter: BexResourceLimiter,
}
impl HostState {
    /// Build the per-store host state for a single plugin instance.
    ///
    /// # Parameters
    /// - `http_client`, `db`, `js_pool`, `runtime`: shared services, stored as given.
    /// - `plugin_id`: identifier used to scope KV/secrets/JS calls and tag log events.
    /// - `manifest`: the plugin's manifest (network allowlist, storage, secrets, `allow_js`).
    /// - `user_agent`: stored as given.
    /// - `http_timeout_ms`: default timeout for HTTP requests that do not set their own.
    /// - `max_response_bytes`: cap on HTTP response bodies returned to the guest.
    /// - `memory_limit_mb`: guest linear-memory ceiling in MiB, stored in bytes on the
    ///   [`BexResourceLimiter`]. Values too large for this host's `usize` saturate to
    ///   `usize::MAX` instead of overflowing.
    ///
    /// The monotonic clock origin (`start_mono`) is captured here, and a fresh, empty
    /// resource table is created.
    pub fn new(
        http_client: Arc<HttpHostService>,
        db: Arc<BexDb>,
        js_pool: Arc<JsPool>,
        runtime: Arc<tokio::runtime::Runtime>,
        plugin_id: String,
        manifest: Arc<Manifest>,
        user_agent: Arc<str>,
        http_timeout_ms: u32,
        max_response_bytes: u64,
        memory_limit_mb: u32,
    ) -> Self {
        // MiB -> bytes, saturating. On a 32-bit host `usize` is 32 bits, so a limit of
        // 4096 MiB or more overflows a plain multiplication: a panic in debug builds, and
        // a wrap-around to a tiny (or zero) limit in release that would refuse every
        // memory growth. Saturating yields "everything addressable" instead.
        let max_memory_bytes = usize::try_from(memory_limit_mb)
            .unwrap_or(usize::MAX)
            .saturating_mul(1024 * 1024);
        // Locked-down WASI: no inherited stdout/stderr, no filesystem, no env, no sockets
        let wasi = wasmtime_wasi::WasiCtxBuilder::new().build();
        Self {
            http_client,
            db,
            js_pool,
            runtime,
            plugin_id,
            manifest,
            user_agent,
            http_timeout_ms,
            max_response_bytes,
            start_mono: std::time::Instant::now(),
            wasi,
            table: wasmtime::component::ResourceTable::new(),
            limiter: BexResourceLimiter { max_memory_bytes },
        }
    }
}
impl IoView for HostState {
    /// Gives wasmtime-wasi mutable access to this store's component resource table.
    fn table(&mut self) -> &mut wasmtime::component::ResourceTable {
        let Self { table, .. } = self;
        table
    }
}
impl wasmtime_wasi::WasiView for HostState {
    /// Gives wasmtime-wasi mutable access to this store's WASI context.
    fn ctx(&mut self) -> &mut wasmtime_wasi::WasiCtx {
        let Self { wasi, .. } = self;
        wasi
    }
}
// ── Implement host traits for each WIT import ──────────────────────────
use crate::engine::bex::plugin;
/// Extract the hostname from a URL string for network-allowlist checks.
///
/// Well-formed URLs go through the `url` crate, so the result is exactly
/// `url::Url::host_str` (e.g. `"example.com"` from `"https://example.com/path?q=1"`;
/// IPv6 literals keep their brackets, e.g. `"[::1]"`).
///
/// Strings that fail to parse, or parse without a host, fall back to
/// [`fallback_host`].
fn extract_host(url_str: &str) -> String {
    url::Url::parse(url_str)
        .ok()
        .and_then(|u| u.host_str().map(str::to_owned))
        .unwrap_or_else(|| fallback_host(url_str))
}

/// Best-effort manual host extraction for strings `url::Url` could not handle.
///
/// The steps are:
/// 1. Skip an optional `scheme://` prefix.
/// 2. Cut the authority at the first `/`, `?` or `#`.
/// 3. Drop any `user:pass@` userinfo, which ends at the last `@`.
/// 4. Strip a `:port` suffix, keeping `[...]` IPv6 literals intact.
/// 5. Lowercase the result, matching the normalization `host_str` applies to domains.
///
/// Userinfo must be removed *before* the port. In `allowed.com:80@evil.com` the real
/// host is `evil.com`; splitting on `:` first would wrongly report `allowed.com`.
fn fallback_host(url_str: &str) -> String {
    let rest = url_str.split_once("://").map_or(url_str, |(_, after)| after);
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(rest);
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, host)| host);
    let host = match host_port.strip_prefix('[').and_then(|inner| inner.find(']')) {
        // `close` indexes into the text after '['; +2 re-includes '[' and ']'.
        Some(close) => &host_port[..close + 2],
        None => host_port.split(':').next().unwrap_or(host_port),
    };
    host.to_ascii_lowercase()
}
/// The `common` interface requires no host methods (hence the empty impl);
/// this file uses its shared types `Attr` and `PluginError`.
impl plugin::common::Host for HostState {}
impl plugin::http::Host for HostState {
    /// Perform an outbound HTTP request on behalf of the plugin.
    ///
    /// The steps are:
    /// 1. The host of the *requested* URL must be on the manifest's network allowlist.
    ///    Otherwise `Forbidden` is returned and no network traffic is generated.
    /// 2. The request runs on the shared Tokio runtime (blocking this host call). It uses
    ///    the request's own timeout, or `http_timeout_ms` when none is given.
    /// 3. The host of the *final* URL reported by the HTTP service (which may differ
    ///    from the requested one, e.g. after redirects) is checked again, so an
    ///    off-allowlist response is never surfaced to the plugin.
    /// 4. The response body is truncated to `max_response_bytes`.
    ///
    /// NOTE(review): any intermediate redirect hops taken inside the HTTP service are
    /// not checked here; only the requested and final hosts are.
    ///
    /// # Errors
    /// - `PluginError::Forbidden`: the requested or final host is not allowed.
    /// - `PluginError::Network(msg)`: the HTTP service reported an error.
    fn send_request(&mut self, req: Request) -> Result<Response, plugin::common::PluginError> {
        // Enforce the allowlist *before* sending. Checking only the response would
        // still let a plugin reach (and leak data to) arbitrary hosts.
        self.ensure_host_allowed(&req.url)?;

        let method_str = match req.method {
            Method::Get => "GET",
            Method::Post => "POST",
            Method::Put => "PUT",
            Method::Delete => "DELETE",
            Method::Head => "HEAD",
            Method::Patch => "PATCH",
            Method::Options => "OPTIONS",
        };
        // `req` is owned, so move its fields out instead of cloning them.
        let headers: Vec<(String, String)> =
            req.headers.into_iter().map(|a| (a.key, a.value)).collect();
        let timeout = req.timeout_ms.or(Some(self.http_timeout_ms));
        let url = req.url;
        let req_body = req.body;

        // Use the shared runtime instead of creating a new one per call.
        let (status, resp_body, resp_headers, final_url) = self
            .runtime
            .block_on(async {
                self.http_client
                    .send_request(method_str, &url, headers, req_body, timeout)
                    .await
            })
            .map_err(|e| plugin::common::PluginError::Network(e.to_string()))?;

        // Re-check: the service may report a different final URL than requested.
        self.ensure_host_allowed(&final_url)?;

        // Oversized bodies are silently truncated to the configured ceiling.
        let limit = usize::try_from(self.max_response_bytes).unwrap_or(usize::MAX);
        let body = resp_body[..resp_body.len().min(limit)].to_vec();

        Ok(Response {
            status,
            headers: resp_headers
                .into_iter()
                .map(|(key, value)| Attr { key, value })
                .collect(),
            body,
            cached: false,
            final_url,
        })
    }
}

impl HostState {
    /// Return `Ok(())` when the host of `url` is on the manifest's network allowlist.
    /// Otherwise log a warning and return `PluginError::Forbidden`.
    fn ensure_host_allowed(&self, url: &str) -> Result<(), plugin::common::PluginError> {
        let host_part = extract_host(url);
        if self.manifest.allows_host(&host_part) {
            return Ok(());
        }
        tracing::warn!(
            plugin = %self.plugin_id,
            url = %url,
            host = %host_part,
            "Domain not in allowlist"
        );
        Err(plugin::common::PluginError::Forbidden)
    }
}
/// Scoped KV wrapper that enforces manifest permissions.
/// A plugin with `storage: false` cannot write to KV.
struct ScopedKv<'a> {
db: &'a BexDb,
plugin_id: &'a str,
storage_allowed: bool,
}
impl<'a> ScopedKv<'a> {
fn new(db: &'a BexDb, plugin_id: &'a str, manifest: &Manifest) -> Self {
Self {
db,
plugin_id,
storage_allowed: manifest.storage,
}
}
fn set(&self, key: &str, value: &[u8], _ttl_seconds: Option<u32>) -> bool {
if !self.storage_allowed {
tracing::warn!(plugin = %self.plugin_id, "KV write blocked: storage=false");
return false;
}
self.db.kv_set(self.plugin_id, key, value).unwrap_or(false)
}
fn get(&self, key: &str) -> Option<Vec<u8>> {
if !self.storage_allowed {
return None;
}
self.db.kv_get(self.plugin_id, key).ok().flatten()
}
fn remove(&self, key: &str) -> bool {
if !self.storage_allowed {
return false;
}
self.db.kv_remove(self.plugin_id, key).unwrap_or(false)
}
fn keys(&self, prefix: &str) -> Vec<String> {
if !self.storage_allowed {
return vec![];
}
self.db.kv_keys(self.plugin_id, prefix).unwrap_or_default()
}
}
/// Scoped secrets wrapper.
struct ScopedSecrets<'a> {
db: &'a BexDb,
plugin_id: &'a str,
secrets_allowed: bool,
}
impl<'a> ScopedSecrets<'a> {
fn new(db: &'a BexDb, plugin_id: &'a str, manifest: &Manifest) -> Self {
// Secrets are allowed if the manifest declares secret keys
Self {
db,
plugin_id,
secrets_allowed: !manifest.secrets.is_empty(),
}
}
fn get(&self, key: &str) -> Option<String> {
if !self.secrets_allowed {
return None;
}
self.db.secret_get(self.plugin_id, key).ok().flatten()
}
}
impl plugin::kv::Host for HostState {
fn set(&mut self, key: String, value: Vec<u8>, ttl_seconds: Option<u32>) -> bool {
let scoped = ScopedKv::new(&self.db, &self.plugin_id, &self.manifest);
scoped.set(&key, &value, ttl_seconds)
}
fn get(&mut self, key: String) -> Option<Vec<u8>> {
let scoped = ScopedKv::new(&self.db, &self.plugin_id, &self.manifest);
scoped.get(&key)
}
fn remove(&mut self, key: String) -> bool {
let scoped = ScopedKv::new(&self.db, &self.plugin_id, &self.manifest);
scoped.remove(&key)
}
fn keys(&mut self, prefix: String) -> Vec<String> {
let scoped = ScopedKv::new(&self.db, &self.plugin_id, &self.manifest);
scoped.keys(&prefix)
}
}
impl plugin::secrets::Host for HostState {
    /// Read a secret through the manifest-enforcing `ScopedSecrets` view.
    fn get(&mut self, key: String) -> Option<String> {
        ScopedSecrets::new(&self.db, &self.plugin_id, &self.manifest).get(&key)
    }
}
impl plugin::log::Host for HostState {
fn write(&mut self, level: Level, msg: String, fields: Vec<Attr>) {
let fields_str: Vec<String> =
fields.iter().map(|a| format!("{}={}", a.key, a.value)).collect();
let fields_joined = fields_str.join(",");
match level {
Level::Trace => {
tracing::trace!(plugin = %self.plugin_id, fields = %fields_joined, "{}", msg)
}
Level::Debug => {
tracing::debug!(plugin = %self.plugin_id, fields = %fields_joined, "{}", msg)
}
Level::Info => {
tracing::info!(plugin = %self.plugin_id, fields = %fields_joined, "{}", msg)
}
Level::Warn => {
tracing::warn!(plugin = %self.plugin_id, fields = %fields_joined, "{}", msg)
}
Level::Error => {
tracing::error!(plugin = %self.plugin_id, fields = %fields_joined, "{}", msg)
}
}
}
}
impl plugin::clock::Host for HostState {
    /// Wall-clock time in milliseconds since the Unix epoch. Returns 0 if the system
    /// clock reads earlier than the epoch.
    fn now_ms(&mut self) -> u64 {
        use std::time::{SystemTime, UNIX_EPOCH};
        let since_epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default();
        since_epoch.as_millis() as u64
    }
    /// Milliseconds elapsed since this `HostState` was constructed.
    fn monotonic(&mut self) -> u64 {
        let elapsed = self.start_mono.elapsed();
        elapsed.as_millis() as u64
    }
}
impl plugin::rng::Host for HostState {
    /// Return a buffer of random bytes filled from `rand::thread_rng()`.
    ///
    /// The length is `len`, clamped to the guest memory ceiling
    /// (`limiter.max_memory_bytes`).
    fn bytes(&mut self, len: u32) -> Vec<u8> {
        use rand::RngCore;
        // `len` is guest-controlled. Uncapped, a single call could make the host
        // allocate up to 4 GiB outside the guest's memory limiter. The buffer has to
        // be copied into guest memory, which (assuming the limiter is installed on
        // the store) cannot exceed `max_memory_bytes`, so anything larger could never
        // be delivered anyway.
        let requested = usize::try_from(len).unwrap_or(usize::MAX);
        let len = requested.min(self.limiter.max_memory_bytes);
        let mut buf = vec![0u8; len];
        rand::thread_rng().fill_bytes(&mut buf);
        buf
    }
}
// ── JS Engine Host Implementation ────────────────────────────────────
impl plugin::js::Host for HostState {
fn eval_js(
&mut self,
code: String,
input: String,
) -> Result<String, plugin::common::PluginError> {
// Permission gate: plugin must have allow_js=true
if !self.manifest.allow_js {
tracing::warn!(plugin = %self.plugin_id, "JS eval blocked: allow_js=false");
return Err(plugin::common::PluginError::Forbidden);
}
match self.js_pool.eval_js(&self.plugin_id, &code, &input) {
Ok(result) => Ok(result),
Err(e) => {
tracing::warn!(plugin = %self.plugin_id, error = %e, "JS eval failed");
Err(js_error_to_plugin_error(e))
}
}
}
fn eval_js_opts(
&mut self,
code: String,
input: String,
opts: JsOpts,
) -> Result<String, plugin::common::PluginError> {
if !self.manifest.allow_js {
tracing::warn!(plugin = %self.plugin_id, "JS eval blocked: allow_js=false");
return Err(plugin::common::PluginError::Forbidden);
}
match self.js_pool.eval_js_opts(
&self.plugin_id,
&code,
&input,
opts.filename,
opts.timeout_ms,
) {
Ok(result) => Ok(result),
Err(e) => {
tracing::warn!(plugin = %self.plugin_id, error = %e, "JS eval_opts failed");
Err(js_error_to_plugin_error(e))
}
}
}
fn call_js_fn(
&mut self,
fn_name: String,
fn_source: String,
args_json: String,
) -> Result<String, plugin::common::PluginError> {
if !self.manifest.allow_js {
tracing::warn!(plugin = %self.plugin_id, "JS call blocked: allow_js=false");
return Err(plugin::common::PluginError::Forbidden);
}
match self
.js_pool
.call_js_fn(&self.plugin_id, &fn_name, &fn_source, &args_json)
{
Ok(result) => Ok(result),
Err(e) => {
tracing::warn!(
plugin = %self.plugin_id,
fn_name = %fn_name,
error = %e,
"JS call_fn failed"
);
Err(js_error_to_plugin_error(e))
}
}
}
fn clear_js_fn(
&mut self,
fn_name: String,
) -> Result<u8, plugin::common::PluginError> {
if !self.manifest.allow_js {
tracing::warn!(plugin = %self.plugin_id, "JS clear blocked: allow_js=false");
return Err(plugin::common::PluginError::Forbidden);
}
match self.js_pool.clear_js_fn(&self.plugin_id, &fn_name) {
Ok(v) => Ok(v),
Err(e) => {
tracing::warn!(
plugin = %self.plugin_id,
fn_name = %fn_name,
error = %e,
"JS clear_fn failed"
);
Err(js_error_to_plugin_error(e))
}
}
}
}
fn js_error_to_plugin_error(e: bex_js::JsError) -> plugin::common::PluginError {
use plugin::common::PluginError;
match e {
bex_js::JsError::Syntax(msg) => PluginError::Parse(format!("JS syntax: {msg}")),
bex_js::JsError::Runtime(msg) => PluginError::Internal(format!("JS runtime: {msg}")),
bex_js::JsError::Timeout(ms) => {
tracing::debug!("JS timeout after {}ms", ms);
PluginError::Timeout
}
bex_js::JsError::OutOfMemory(_) => PluginError::Internal("JS out of memory".into()),
bex_js::JsError::Execution(msg) => PluginError::Internal(format!("JS: {msg}")),
bex_js::JsError::FunctionNotFound(_name) => PluginError::NotFound,
bex_js::JsError::PermissionDenied(_msg) => PluginError::Forbidden,
bex_js::JsError::InvalidJson(msg) => {
PluginError::InvalidInput(format!("JS args: {msg}"))
}
bex_js::JsError::PoolBusy => PluginError::RateLimited(None),
bex_js::JsError::PoolShutdown => PluginError::Internal("JS pool shut down".into()),
bex_js::JsError::Internal(msg) => PluginError::Internal(format!("JS: {msg}")),
}
}