krystv's picture
Upload 107 files
3374e90 verified
//! Worker thread implementation for the JS pool.
use crate::config::JsPoolConfig;
use crate::error::JsError;
use crate::polyfills;
use crate::pool::simple_hash;
use crossbeam_channel::{Receiver, Sender};
use rquickjs::context::EvalOptions;
use rquickjs::promise::PromiseState;
use rquickjs::{CatchResultExt, Context, Runtime};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// A unit of work queued for a JS worker thread.
pub struct JsTask {
/// Plugin whose JS context the task runs in; a worker keeps one context
/// per plugin id it has seen.
pub plugin_id: String,
/// The operation to perform.
pub kind: JsTaskKind,
/// The worker sends a single `JsResult` here once the task is handled;
/// a failed send (e.g. receiver already dropped) is ignored.
pub reply: Sender<JsResult>,
}
/// Operations a JS worker can perform inside a plugin's context.
pub enum JsTaskKind {
/// Evaluate `code` as a script and return its completion value as JSON.
Eval {
/// Script source to evaluate.
code: String,
/// Exposed to the script as the global string `input` (set as a value,
/// never evaluated as code).
input: String,
/// Optional filename forwarded to QuickJS `EvalOptions`
/// (NOTE(review): presumably used in error locations — confirm).
filename: Option<String>,
/// Execution budget in milliseconds; 0 selects the pool's default.
timeout_ms: u32,
},
/// Call the global function `name` with `args_json` as its single string
/// argument, first (re)registering it by evaluating `fn_source` when the
/// source's hash differs from the last registration.
CallFn {
name: String,
fn_source: String,
/// JSON text passed through unchanged; the JS function is expected to
/// `JSON.parse` it itself.
args_json: String,
/// Execution budget in milliseconds; 0 selects the pool's default.
timeout_ms: u32,
},
/// Delete the global `name` and forget its registration hash.
ClearFn {
name: String,
},
/// Drop the plugin's whole JS context.
Evict,
}
/// Reply to a `JsTask`.
pub struct JsResult {
/// On success: the result serialized as JSON text (`"null"` when the
/// value is `undefined`), or a short status string for `ClearFn`/`Evict`.
pub result: Result<String, JsError>,
}
/// Per-plugin state held by a worker: one QuickJS context plus bookkeeping.
struct PluginContext {
/// The plugin's QuickJS context; globals such as `input` and registered
/// functions live here.
ctx: Context,
/// Maps function name → source hash (u64). Used for change detection.
/// When call-js-fn is called with a fn_source whose hash differs,
/// the function is automatically re-registered.
functions: HashMap<String, u64>,
/// Last time a task touched this context; contexts idle for longer than
/// `context_idle_ttl_secs` are evicted.
last_access: std::time::Instant,
}
/// A JS executor running on its own thread: owns one QuickJS runtime and
/// lazily creates one context per plugin id it receives tasks for.
pub struct JsWorker {
/// Worker index, used in the thread name and log fields.
id: usize,
/// Limits and timeouts applied by this worker.
config: JsPoolConfig,
/// Incoming tasks (NOTE(review): presumably shared with sibling workers
/// via cloned receivers — confirm in the pool).
rx: Receiver<JsTask>,
/// Shutdown flag; the run loop exits once it reads `true`.
shutdown: Arc<AtomicBool>,
/// Plugin id → context.
contexts: HashMap<String, PluginContext>,
/// Runtime every context in `contexts` is created from; its memory limit,
/// stack limit and interrupt handler therefore apply to all of them.
runtime: Runtime,
}
impl JsWorker {
pub fn spawn(
id: usize,
config: JsPoolConfig,
rx: Receiver<JsTask>,
shutdown: Arc<AtomicBool>,
) -> std::thread::JoinHandle<()> {
let rt = Runtime::new().expect("Failed to create QuickJS runtime");
rt.set_memory_limit(config.memory_limit_bytes);
rt.set_max_stack_size(config.max_stack_bytes);
let worker = Self {
id,
config,
rx,
shutdown,
contexts: HashMap::new(),
runtime: rt,
};
std::thread::Builder::new()
.name(format!("bex-js-worker-{}", id))
.spawn(move || worker.run())
.expect("Failed to spawn JS worker thread")
}
fn run(mut self) {
tracing::debug!(worker = self.id, "JS worker started");
let mut last_evict = std::time::Instant::now();
let evict_interval = std::time::Duration::from_secs(30);
while !self.shutdown.load(Ordering::Acquire) {
// Only check for idle contexts periodically (not every loop iteration)
if last_evict.elapsed() > evict_interval {
self.evict_idle_contexts();
self.runtime.run_gc();
last_evict = std::time::Instant::now();
}
match self.rx.recv_timeout(std::time::Duration::from_secs(5)) {
Ok(task) => self.handle_task(task),
Err(crossbeam_channel::RecvTimeoutError::Timeout) => continue,
Err(crossbeam_channel::RecvTimeoutError::Disconnected) => break,
}
}
tracing::debug!(worker = self.id, "JS worker shutting down");
self.contexts.clear();
}
fn handle_task(&mut self, task: JsTask) {
let result = match task.kind {
JsTaskKind::Eval {
code,
input,
filename,
timeout_ms,
} => self.eval_js(&task.plugin_id, &code, &input, filename, timeout_ms),
JsTaskKind::CallFn {
name,
fn_source,
args_json,
timeout_ms,
} => self.call_fn(&task.plugin_id, &name, &fn_source, &args_json, timeout_ms),
JsTaskKind::ClearFn { name } => {
if let Some(pctx) = self.contexts.get_mut(&task.plugin_id) {
pctx.functions.remove(&name);
pctx.ctx.with(|ctx| {
let _ = ctx.globals().remove(&name as &str);
});
}
Ok("ok".to_string())
}
JsTaskKind::Evict => {
self.contexts.remove(&task.plugin_id);
Ok("evicted".into())
}
};
let _ = task.reply.send(JsResult { result });
}
fn ensure_context(&mut self, plugin_id: &str) {
if !self.contexts.contains_key(plugin_id) {
let ctx = Context::full(&self.runtime).expect("Failed to create JS context");
ctx.with(|ctx| {
let _ = polyfills::install(&ctx);
});
self.contexts.insert(
plugin_id.to_string(),
PluginContext {
ctx,
functions: HashMap::new(),
last_access: std::time::Instant::now(),
},
);
}
if let Some(pc) = self.contexts.get_mut(plugin_id) {
pc.last_access = std::time::Instant::now();
}
}
fn eval_js(
&mut self,
plugin_id: &str,
code: &str,
input: &str,
filename: Option<String>,
timeout_ms: u32,
) -> Result<String, JsError> {
self.ensure_context(plugin_id);
let timeout = if timeout_ms > 0 {
timeout_ms
} else {
self.config.default_timeout_ms
};
let mem_limit_mb = (self.config.memory_limit_bytes / (1024 * 1024)) as u32;
// Safe input injection — NOT eval, just set a global string
{
let pctx = self.contexts.get(plugin_id).unwrap();
pctx.ctx.with(|ctx| -> rquickjs::Result<()> {
ctx.globals().set("input", input)?;
Ok(())
}).map_err(|e| JsError::Execution(e.to_string()))?;
}
let start = std::time::Instant::now();
self.runtime.set_interrupt_handler(Some(Box::new(move || {
start.elapsed().as_millis() as u32 > timeout
})));
let result = self
.contexts
.get(plugin_id)
.unwrap()
.ctx
.with(|ctx| {
let mut opts = EvalOptions::default();
if let Some(ref fname) = filename {
opts.filename = Some(fname.clone());
}
let eval_result: rquickjs::Value = ctx
.eval_with_options(code, opts)
.catch(&ctx)
.map_err(|e| js_err_from_caught(e, timeout, mem_limit_mb))?;
// If the eval result is a Promise, resolve it by flushing pending
// microtasks. This is essential for async functions like
// crypto.subtle.digest which are synchronous under the hood but
// return Promises due to the async keyword.
let final_result = if eval_result.is_promise() {
// Flush pending jobs to resolve the Promise
while ctx.execute_pending_job() {}
// Get the resolved value from the Promise
if let Some(promise) = eval_result.as_promise() {
match promise.state() {
PromiseState::Resolved => {
match promise.result::<rquickjs::Value>() {
Some(Ok(val)) => val,
_ => eval_result.clone(),
}
}
PromiseState::Rejected => {
// Get the rejection reason
let reason = if let Some(promise) = eval_result.as_promise() {
match promise.result::<rquickjs::Value>() {
Some(Ok(v)) => {
if v.is_string() {
v.as_string().and_then(|s| s.to_string().ok()).unwrap_or_else(|| "unknown".to_string())
} else {
format!("{:?}", v)
}
}
_ => "Promise rejected".to_string(),
}
} else {
"Promise rejected".to_string()
};
return Err(JsError::Execution(format!("Promise rejected: {}", reason)));
}
PromiseState::Pending => eval_result.clone(),
}
} else {
eval_result.clone()
}
} else {
eval_result.clone()
};
let json_str = if final_result.is_undefined() {
"null".to_string()
} else {
match ctx.json_stringify(&final_result) {
Ok(Some(s)) => s
.to_string()
.map_err(|e| JsError::Execution(e.to_string()))?,
Ok(None) => "null".to_string(),
Err(_) => "null".to_string(),
}
};
Ok::<String, JsError>(json_str)
});
// Flush any remaining pending Promise microtasks (e.g. from .then() chains
// that were set up but not awaited). This ensures side effects from Promise
// callbacks are visible in subsequent eval calls.
while self.runtime.execute_pending_job().unwrap_or(false) {}
self.runtime.set_interrupt_handler(None);
result
}
fn call_fn(
&mut self,
plugin_id: &str,
name: &str,
fn_source: &str,
args_json: &str,
timeout_ms: u32,
) -> Result<String, JsError> {
self.ensure_context(plugin_id);
// Auto-register or re-register the function based on source hash
let src_hash = simple_hash(fn_source);
let needs_register = self
.contexts
.get(plugin_id)
.unwrap()
.functions
.get(name)
.map_or(true, |&h| h != src_hash);
if needs_register {
let pctx = self.contexts.get(plugin_id).unwrap();
pctx.ctx.with(|ctx| {
ctx.eval::<(), _>(fn_source)
.catch(&ctx)
.map_err(|e| {
JsError::Execution(format!("fn registration '{}': {}", name, e))
})
})?;
self.contexts
.get_mut(plugin_id)
.unwrap()
.functions
.insert(name.to_string(), src_hash);
}
let timeout = if timeout_ms > 0 {
timeout_ms
} else {
self.config.default_timeout_ms
};
let mem_limit_mb = (self.config.memory_limit_bytes / (1024 * 1024)) as u32;
let start = std::time::Instant::now();
self.runtime.set_interrupt_handler(Some(Box::new(move || {
start.elapsed().as_millis() as u32 > timeout
})));
let name_owned = name.to_string();
let result = self
.contexts
.get(plugin_id)
.unwrap()
.ctx
.with(|ctx| {
let func: rquickjs::Function = ctx
.globals()
.get(name_owned.clone())
.map_err(|_| JsError::FunctionNotFound(name_owned.clone()))?;
// SAFE: pass args_json as a string — NO eval of user data
// The JS function receives it as a string and calls JSON.parse(args) internally
let result_val: rquickjs::Value = func
.call((args_json,))
.catch(&ctx)
.map_err(|e| js_err_from_caught(e, timeout, mem_limit_mb))?;
let json_str = if result_val.is_undefined() {
"null".to_string()
} else {
match ctx.json_stringify(&result_val) {
Ok(Some(s)) => s
.to_string()
.map_err(|e| JsError::Execution(e.to_string()))?,
Ok(None) => "null".to_string(),
Err(_) => "null".to_string(),
}
};
Ok::<String, JsError>(json_str)
});
self.runtime.set_interrupt_handler(None);
result
}
fn evict_idle_contexts(&mut self) {
let ttl = std::time::Duration::from_secs(self.config.context_idle_ttl_secs);
let now = std::time::Instant::now();
self.contexts.retain(|pid, ctx| {
if now.duration_since(ctx.last_access) > ttl {
tracing::debug!(worker = self.id, plugin = %pid, "Evicting idle JS context");
false
} else {
true
}
});
}
}
/// Classifies a caught QuickJS error into a `JsError` by inspecting its
/// rendered message.
///
/// Checks run in priority order: interrupt/timeout wording first, then
/// memory/allocation wording, then syntax wording; anything else is a plain
/// execution error. `timeout_ms` and `mem_limit_mb` only label the resulting
/// `Timeout` / `OutOfMemory` variants.
///
/// NOTE(review): this is substring matching, so a user-thrown message that
/// happens to contain e.g. "timeout" or "memory" is classified the same way.
fn js_err_from_caught(e: rquickjs::CaughtError, timeout_ms: u32, mem_limit_mb: u32) -> JsError {
    let text = e.to_string();
    let mentions = |needles: &[&str]| needles.iter().any(|needle| text.contains(needle));

    if mentions(&["interrupted", "timeout"]) {
        return JsError::Timeout(timeout_ms);
    }
    if mentions(&["memory", "alloc"]) {
        return JsError::OutOfMemory(mem_limit_mb);
    }
    if mentions(&["SyntaxError", "syntax"]) {
        return JsError::Syntax(text);
    }
    JsError::Execution(text)
}