476 lines
17 KiB
Rust
476 lines
17 KiB
Rust
//! OpenRPC server implementation.
|
||
|
||
use jsonrpsee::{
|
||
core::{RpcResult, async_trait},
|
||
server::middleware::rpc::{RpcServiceT, RpcServiceBuilder, MethodResponse},
|
||
proc_macros::rpc,
|
||
server::{Server, ServerHandle},
|
||
types::{ErrorObject, ErrorObjectOwned},
|
||
};
|
||
use tower_http::cors::{CorsLayer, Any};
|
||
|
||
use anyhow;
|
||
use log::{debug, info, error};
|
||
|
||
use crate::{auth::ApiKey, supervisor::Supervisor};
|
||
use crate::error::SupervisorError;
|
||
use hero_job::{Job, JobResult, JobStatus};
|
||
use serde::{Deserialize, Serialize};
|
||
|
||
use std::net::SocketAddr;
|
||
use std::sync::Arc;
|
||
use std::fs;
|
||
use tokio::sync::Mutex;
|
||
|
||
/// Load OpenRPC specification from docs/supervisor/openrpc.json
|
||
fn load_openrpc_spec() -> Result<serde_json::Value, Box<dyn std::error::Error>> {
|
||
// Path relative to the workspace root (where Cargo.toml is)
|
||
let path = concat!(env!("CARGO_MANIFEST_DIR"), "/../../docs/supervisor/openrpc.json");
|
||
let content = fs::read_to_string(path)?;
|
||
let spec = serde_json::from_str(&content)?;
|
||
debug!("Loaded OpenRPC specification from: {}", path);
|
||
Ok(spec)
|
||
}
|
||
|
||
/// Request parameters for generating API keys (auto-generates key value)
|
||
#[derive(Debug, Deserialize, Serialize)]
|
||
pub struct GenerateApiKeyParams {
|
||
pub name: String,
|
||
pub scope: String, // "admin", "registrar", or "user"
|
||
}
|
||
|
||
/// Job status response with metadata
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct JobStatusResponse {
|
||
pub job_id: String,
|
||
pub status: String,
|
||
pub created_at: String,
|
||
}
|
||
|
||
/// Supervisor information response
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct SupervisorInfo {
|
||
pub server_url: String,
|
||
}
|
||
|
||
/// OpenRPC trait - maps directly to Supervisor methods
|
||
/// This trait exists only for jsonrpsee's macro system.
|
||
/// The implementation below is just error type conversion -
|
||
/// all actual logic lives in Supervisor methods.
|
||
#[rpc(server)]
|
||
pub trait SupervisorRpc {
|
||
/// Create a job without queuing it to a runner
|
||
#[method(name = "job.create")]
|
||
async fn job_create(&self, params: Job) -> RpcResult<String>;
|
||
|
||
/// Get a job by job ID
|
||
#[method(name = "job.get")]
|
||
async fn job_get(&self, job_id: String) -> RpcResult<Job>;
|
||
|
||
/// Start a previously created job by queuing it to its assigned runner
|
||
#[method(name = "job.start")]
|
||
async fn job_start(&self, job_id: String) -> RpcResult<()>;
|
||
|
||
/// Run a job on the appropriate runner and return the result
|
||
#[method(name = "job.run")]
|
||
async fn job_run(&self, params: Job) -> RpcResult<JobResult>;
|
||
|
||
/// Get the current status of a job
|
||
#[method(name = "job.status")]
|
||
async fn job_status(&self, job_id: String) -> RpcResult<JobStatus>;
|
||
|
||
/// Get the result of a completed job (blocks until result is available)
|
||
#[method(name = "job.result")]
|
||
async fn job_result(&self, job_id: String) -> RpcResult<JobResult>;
|
||
|
||
/// Get logs for a specific job
|
||
#[method(name = "job.logs")]
|
||
async fn job_logs(&self, job_id: String) -> RpcResult<Vec<String>>;
|
||
|
||
/// Stop a running job
|
||
#[method(name = "job.stop")]
|
||
async fn job_stop(&self, job_id: String) -> RpcResult<()>;
|
||
|
||
/// Delete a job from the system
|
||
#[method(name = "job.delete")]
|
||
async fn job_delete(&self, job_id: String) -> RpcResult<()>;
|
||
|
||
/// List all jobs
|
||
#[method(name = "job.list")]
|
||
async fn job_list(&self) -> RpcResult<Vec<Job>>;
|
||
|
||
/// Add a runner with configuration
|
||
#[method(name = "runner.create")]
|
||
async fn runner_create(&self, runner_id: String) -> RpcResult<()>;
|
||
|
||
/// Delete a runner from the supervisor
|
||
#[method(name = "runner.remove")]
|
||
async fn runner_delete(&self, runner_id: String) -> RpcResult<()>;
|
||
|
||
/// List all runner IDs
|
||
#[method(name = "runner.list")]
|
||
async fn runner_list(&self) -> RpcResult<Vec<String>>;
|
||
|
||
/// Ping a runner (dispatch a ping job)
|
||
#[method(name = "runner.ping")]
|
||
async fn ping_runner(&self, runner_id: String) -> RpcResult<String>;
|
||
|
||
/// Create an API key with provided key value
|
||
#[method(name = "key.create")]
|
||
async fn key_create(&self, key: ApiKey) -> RpcResult<()>;
|
||
|
||
/// Generate a new API key with auto-generated key value
|
||
#[method(name = "key.generate")]
|
||
async fn key_generate(&self, params: GenerateApiKeyParams) -> RpcResult<ApiKey>;
|
||
|
||
/// Delete an API key
|
||
#[method(name = "key.delete")]
|
||
async fn key_delete(&self, key_id: String) -> RpcResult<()>;
|
||
|
||
/// List all secrets (returns counts only for security)
|
||
#[method(name = "key.list")]
|
||
async fn key_list(&self) -> RpcResult<Vec<ApiKey>>;
|
||
|
||
/// Verify an API key and return its metadata
|
||
#[method(name = "auth.verify")]
|
||
async fn auth_verify(&self) -> RpcResult<crate::auth::AuthVerifyResponse>;
|
||
|
||
/// Get supervisor information
|
||
#[method(name = "supervisor.info")]
|
||
async fn supervisor_info(&self) -> RpcResult<SupervisorInfo>;
|
||
|
||
/// OpenRPC discovery method - returns the OpenRPC document describing this API
|
||
#[method(name = "rpc.discover")]
|
||
async fn rpc_discover(&self) -> RpcResult<serde_json::Value>;
|
||
}
|
||
|
||
/// RPC implementation on Supervisor
|
||
///
|
||
/// This implementation is ONLY for error type conversion (SupervisorError → ErrorObject).
|
||
/// All business logic is in Supervisor methods - these are thin wrappers.
|
||
/// Authorization is handled by middleware before methods are called.
|
||
#[async_trait]
|
||
impl SupervisorRpcServer for Supervisor {
|
||
async fn job_create(&self, job: Job) -> RpcResult<String> {
|
||
Ok(self.job_create(job).await?)
|
||
}
|
||
|
||
async fn job_get(&self, job_id: String) -> RpcResult<Job> {
|
||
Ok(self.job_get(&job_id).await?)
|
||
}
|
||
|
||
async fn job_list(&self) -> RpcResult<Vec<Job>> {
|
||
let job_ids = self.job_list().await;
|
||
let mut jobs = Vec::new();
|
||
for job_id in job_ids {
|
||
if let Ok(job) = self.job_get(&job_id).await {
|
||
jobs.push(job);
|
||
}
|
||
}
|
||
Ok(jobs)
|
||
}
|
||
|
||
async fn job_run(&self, job: Job) -> RpcResult<JobResult> {
|
||
let output = self.job_run(job).await?;
|
||
Ok(JobResult::Success { success: output })
|
||
}
|
||
|
||
async fn job_start(&self, job_id: String) -> RpcResult<()> {
|
||
self.job_start(&job_id).await?;
|
||
Ok(())
|
||
}
|
||
|
||
async fn job_status(&self, job_id: String) -> RpcResult<JobStatus> {
|
||
Ok(self.job_status(&job_id).await?)
|
||
}
|
||
|
||
async fn job_logs(&self, job_id: String) -> RpcResult<Vec<String>> {
|
||
Ok(self.job_logs(&job_id, None).await?)
|
||
}
|
||
|
||
async fn job_result(&self, job_id: String) -> RpcResult<JobResult> {
|
||
match self.job_result(&job_id).await? {
|
||
Some(result) => {
|
||
if result.starts_with("Error:") {
|
||
Ok(JobResult::Error { error: result })
|
||
} else {
|
||
Ok(JobResult::Success { success: result })
|
||
}
|
||
},
|
||
None => Ok(JobResult::Error { error: "Job result not available".to_string() })
|
||
}
|
||
}
|
||
|
||
async fn job_stop(&self, job_id: String) -> RpcResult<()> {
|
||
self.job_stop(&job_id).await?;
|
||
Ok(())
|
||
}
|
||
|
||
async fn job_delete(&self, job_id: String) -> RpcResult<()> {
|
||
self.job_delete(&job_id).await?;
|
||
Ok(())
|
||
}
|
||
|
||
async fn runner_create(&self, runner_id: String) -> RpcResult<()> {
|
||
self.runner_create(runner_id).await?;
|
||
Ok(())
|
||
}
|
||
|
||
async fn runner_delete(&self, runner_id: String) -> RpcResult<()> {
|
||
Ok(self.runner_delete(&runner_id).await?)
|
||
}
|
||
|
||
async fn runner_list(&self) -> RpcResult<Vec<String>> {
|
||
Ok(self.runner_list().await)
|
||
}
|
||
|
||
|
||
async fn ping_runner(&self, runner_id: String) -> RpcResult<String> {
|
||
Ok(self.runner_ping(&runner_id).await?)
|
||
}
|
||
|
||
async fn key_create(&self, key: ApiKey) -> RpcResult<()> {
|
||
let _ = self.key_create(key).await;
|
||
Ok(())
|
||
}
|
||
|
||
async fn key_generate(&self, params: GenerateApiKeyParams) -> RpcResult<ApiKey> {
|
||
// Parse scope
|
||
let api_scope = match params.scope.to_lowercase().as_str() {
|
||
"admin" => crate::auth::ApiKeyScope::Admin,
|
||
"registrar" => crate::auth::ApiKeyScope::Registrar,
|
||
"user" => crate::auth::ApiKeyScope::User,
|
||
_ => return Err(ErrorObject::owned(-32602, "Invalid scope. Must be 'admin', 'registrar', or 'user'", None::<()>)),
|
||
};
|
||
|
||
let api_key = self.create_api_key(params.name, api_scope).await;
|
||
Ok(api_key)
|
||
}
|
||
|
||
async fn key_delete(&self, key_id: String) -> RpcResult<()> {
|
||
self.key_delete(&key_id).await
|
||
.ok_or_else(|| ErrorObject::owned(-32603, "API key not found", None::<()>))?;
|
||
Ok(())
|
||
}
|
||
|
||
async fn key_list(&self) -> RpcResult<Vec<ApiKey>> {
|
||
Ok(self.key_list().await)
|
||
}
|
||
|
||
async fn auth_verify(&self) -> RpcResult<crate::auth::AuthVerifyResponse> {
|
||
// If this method is called, middleware already verified the key
|
||
// So we just return success - the middleware wouldn't have let an invalid key through
|
||
Ok(crate::auth::AuthVerifyResponse {
|
||
valid: true,
|
||
name: "verified".to_string(),
|
||
scope: "authenticated".to_string(),
|
||
})
|
||
}
|
||
|
||
async fn supervisor_info(&self) -> RpcResult<SupervisorInfo> {
|
||
Ok(SupervisorInfo {
|
||
server_url: "http://127.0.0.1:3031".to_string(), // TODO: get from config
|
||
})
|
||
}
|
||
|
||
async fn rpc_discover(&self) -> RpcResult<serde_json::Value> {
|
||
debug!("OpenRPC request: rpc.discover");
|
||
|
||
// Read OpenRPC specification from docs/openrpc.json
|
||
match load_openrpc_spec() {
|
||
Ok(spec) => Ok(spec),
|
||
Err(e) => {
|
||
error!("Failed to load OpenRPC specification: {}", e);
|
||
// Fallback to a minimal spec if file loading fails
|
||
Ok(serde_json::json!({
|
||
"openrpc": "1.3.2",
|
||
"info": {
|
||
"title": "Hero Supervisor OpenRPC API",
|
||
"version": "1.0.0",
|
||
"description": "OpenRPC API for managing Hero Supervisor runners and jobs"
|
||
},
|
||
"methods": [],
|
||
"error": "Failed to load full specification"
|
||
}))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Authorization middleware using RpcServiceT
|
||
/// This middleware is created per-connection and checks permissions for each RPC call
|
||
#[derive(Clone)]
|
||
struct AuthMiddleware<S> {
|
||
supervisor: Supervisor,
|
||
inner: S,
|
||
}
|
||
|
||
impl<S> RpcServiceT for AuthMiddleware<S>
|
||
where
|
||
S: RpcServiceT<MethodResponse = MethodResponse> + Send + Sync + Clone + 'static,
|
||
{
|
||
type MethodResponse = MethodResponse;
|
||
type BatchResponse = S::BatchResponse;
|
||
type NotificationResponse = S::NotificationResponse;
|
||
|
||
fn call<'a>(&self, req: jsonrpsee::server::middleware::rpc::Request<'a>) -> impl std::future::Future<Output = Self::MethodResponse> + Send + 'a {
|
||
let supervisor = self.supervisor.clone();
|
||
let inner = self.inner.clone();
|
||
let method = req.method_name().to_string();
|
||
let id = req.id();
|
||
|
||
Box::pin(async move {
|
||
// Check if method requires auth
|
||
let required_scopes = match crate::auth::get_method_required_scopes(&method) {
|
||
None => {
|
||
// Public method - no auth required
|
||
debug!("ℹ️ Public method: {}", method);
|
||
return inner.call(req).await;
|
||
}
|
||
Some(scopes) => scopes,
|
||
};
|
||
|
||
// Extract Authorization header from extensions
|
||
let headers = req.extensions().get::<hyper::HeaderMap>();
|
||
|
||
let api_key = headers
|
||
.and_then(|h| h.get(hyper::header::AUTHORIZATION))
|
||
.and_then(|value| value.to_str().ok())
|
||
.and_then(|s| s.strip_prefix("Bearer "))
|
||
.map(|k| k.to_string());
|
||
|
||
let api_key = match api_key {
|
||
Some(key) => key,
|
||
None => {
|
||
error!("❌ Missing Authorization header for method: {}", method);
|
||
let err = ErrorObjectOwned::owned(
|
||
-32001,
|
||
format!("Missing Authorization header for method: {}", method),
|
||
None::<()>,
|
||
);
|
||
return MethodResponse::error(id, err);
|
||
}
|
||
};
|
||
|
||
// Verify API key and check scope
|
||
let key_obj = match supervisor.key_get(&api_key).await {
|
||
Some(k) => k,
|
||
None => {
|
||
error!("❌ Invalid API key");
|
||
let err = ErrorObjectOwned::owned(-32001, "Invalid API key", None::<()>);
|
||
return MethodResponse::error(id, err);
|
||
}
|
||
};
|
||
|
||
if !required_scopes.contains(&key_obj.scope) {
|
||
error!(
|
||
"❌ Unauthorized: method '{}' requires {:?}, got {:?}",
|
||
method, required_scopes, key_obj.scope
|
||
);
|
||
let err = ErrorObjectOwned::owned(
|
||
-32001,
|
||
format!(
|
||
"Insufficient permissions for '{}'. Required: {:?}, Got: {:?}",
|
||
method, required_scopes, key_obj.scope
|
||
),
|
||
None::<()>,
|
||
);
|
||
return MethodResponse::error(id, err);
|
||
}
|
||
|
||
debug!("✅ Authorized: {} with scope {:?}", method, key_obj.scope);
|
||
|
||
// Authorized - proceed with the call
|
||
inner.call(req).await
|
||
})
|
||
}
|
||
|
||
fn batch<'a>(&self, batch: jsonrpsee::server::middleware::rpc::Batch<'a>) -> impl std::future::Future<Output = Self::BatchResponse> + Send + 'a {
|
||
// For simplicity, pass through batch requests
|
||
// In production, you'd want to check each request in the batch
|
||
self.inner.batch(batch)
|
||
}
|
||
|
||
fn notification<'a>(&self, notif: jsonrpsee::server::middleware::rpc::Notification<'a>) -> impl std::future::Future<Output = Self::NotificationResponse> + Send + 'a {
|
||
self.inner.notification(notif)
|
||
}
|
||
}
|
||
|
||
/// Tower HTTP middleware that copies each request's headers into its extensions.
///
/// `AuthMiddleware` reads the `Authorization` header from the request
/// extensions as a `hyper::HeaderMap`, so this service must wrap the JSON-RPC
/// HTTP service for authorization to see any headers.
#[derive(Clone)]
struct HeaderPropagationService<S> {
    /// The wrapped HTTP service that receives the annotated request.
    inner: S,
}
|
||
|
||
impl<S, B> tower::Service<hyper::Request<B>> for HeaderPropagationService<S>
|
||
where
|
||
S: tower::Service<hyper::Request<B>> + Clone + Send + 'static,
|
||
S::Future: Send + 'static,
|
||
B: Send + 'static,
|
||
{
|
||
type Response = S::Response;
|
||
type Error = S::Error;
|
||
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||
|
||
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
|
||
self.inner.poll_ready(cx)
|
||
}
|
||
|
||
fn call(&mut self, mut req: hyper::Request<B>) -> Self::Future {
|
||
let headers = req.headers().clone();
|
||
req.extensions_mut().insert(headers);
|
||
let fut = self.inner.call(req);
|
||
Box::pin(fut)
|
||
}
|
||
}
|
||
|
||
/// Start HTTP OpenRPC server (Unix socket support would require additional dependencies)
|
||
pub async fn start_http_openrpc_server(
|
||
supervisor: Supervisor,
|
||
bind_address: &str,
|
||
port: u16,
|
||
) -> anyhow::Result<ServerHandle> {
|
||
let http_addr: SocketAddr = format!("{}:{}", bind_address, port).parse()?;
|
||
|
||
// Configure CORS to allow requests from the admin UI
|
||
// Note: Authorization header must be explicitly listed, not covered by Any
|
||
use tower_http::cors::AllowHeaders;
|
||
let cors = CorsLayer::new()
|
||
.allow_origin(Any)
|
||
.allow_headers(AllowHeaders::list([
|
||
hyper::header::CONTENT_TYPE,
|
||
hyper::header::AUTHORIZATION,
|
||
]))
|
||
.allow_methods(Any)
|
||
.expose_headers(Any);
|
||
|
||
// Build RPC middleware with authorization (per-connection)
|
||
let supervisor_for_middleware = supervisor.clone();
|
||
let rpc_middleware = RpcServiceBuilder::new().layer_fn(move |service| {
|
||
// This closure runs once per connection
|
||
AuthMiddleware {
|
||
supervisor: supervisor_for_middleware.clone(),
|
||
inner: service,
|
||
}
|
||
});
|
||
|
||
// Build HTTP middleware stack with CORS and header propagation
|
||
let http_middleware = tower::ServiceBuilder::new()
|
||
.layer(cors)
|
||
.layer(tower::layer::layer_fn(|service| {
|
||
HeaderPropagationService { inner: service }
|
||
}));
|
||
|
||
let http_server = Server::builder()
|
||
.set_rpc_middleware(rpc_middleware)
|
||
.set_http_middleware(http_middleware)
|
||
.build(http_addr)
|
||
.await?;
|
||
|
||
let http_handle = http_server.start(supervisor.into_rpc());
|
||
|
||
info!("OpenRPC HTTP server running at http://{} with CORS enabled", http_addr);
|
||
|
||
Ok(http_handle)
|
||
}
|