167 lines
6.2 KiB
Python
167 lines
6.2 KiB
Python
import threading
|
|
import time
|
|
from typing import Dict, List, Optional
|
|
from dataclasses import asdict
|
|
from datetime import datetime
|
|
import uvicorn
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from .task_runner import TaskRunner
|
|
|
|
|
|
class TaskRunnerAPI:
    """FastAPI interface for the task runner.

    Every route is a read-only GET that inspects ``runner.dag`` (and the
    runner's ``run_id`` / ``run_name``) at request time, so responses always
    reflect the live state rather than a snapshot taken at construction.
    """

    TITLE = "Task Runner API"
    VERSION = "1.0.0"

    # Task states counted in a directory's ``tasks_error`` tally. A tuple keeps
    # plain ``==`` membership semantics for whatever type ``task.state`` is.
    ERROR_STATES = ("ERROR", "CRASHED", "TIMED_OUT")

    def __init__(self, runner: TaskRunner):
        """Create the FastAPI app and register its routes.

        Args:
            runner: Task runner whose DAG the endpoints report on.
        """
        self.runner = runner
        self.app = FastAPI(title=self.TITLE, version=self.VERSION)

        # NOTE(review): wildcard origins combined with allow_credentials=True
        # is very permissive; restrict allow_origins if this API is reachable
        # from outside a trusted network.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        self._setup_routes()

    def _find_directory(self, dir_num: int):
        """Return the first DAG directory whose ``directory_num`` equals *dir_num*, or None."""
        return next(
            (d for d in self.runner.dag.directories if d.directory_num == dir_num),
            None,
        )

    def _setup_routes(self):
        """Register all HTTP routes on ``self.app``."""

        @self.app.get("/")
        async def root():
            """Get API information."""
            return {
                "name": self.TITLE,
                "version": self.VERSION,
                "run_id": self.runner.run_id,
                "run_name": self.runner.run_name,
            }

        @self.app.get("/status")
        async def get_status():
            """Get current run status."""
            # Read the DAG reference once so every field comes from one object.
            dag = self.runner.dag
            return {
                "run_id": self.runner.run_id,
                "run_name": self.runner.run_name,
                "state": dag.state,
                "start_time": dag.start_time,
                "end_time": dag.end_time,
                "duration_seconds": dag.duration_seconds,
                "total_directories": dag.total_directories,
                "completed_directories": dag.completed_directories,
                "failed_directories": dag.failed_directories,
            }

        @self.app.get("/directories")
        async def get_directories():
            """Get all directory statuses."""
            return [
                {
                    "directory_num": d.directory_num,
                    "directory_path": d.directory_path,
                    "state": d.state,
                    "timeout": d.timeout,
                    "start_time": d.start_time,
                    "end_time": d.end_time,
                    "duration_seconds": d.duration_seconds,
                    "task_count": len(d.tasks),
                    "tasks_done": sum(1 for t in d.tasks if t.state == "DONE"),
                    "tasks_error": sum(
                        1 for t in d.tasks if t.state in self.ERROR_STATES
                    ),
                }
                for d in self.runner.dag.directories
            ]

        @self.app.get("/directories/{dir_num}/tasks")
        async def get_directory_tasks(dir_num: int):
            """Get tasks for a specific directory."""
            directory = self._find_directory(dir_num)
            if directory is None:
                raise HTTPException(status_code=404, detail="Directory not found")
            return directory.tasks

        @self.app.get("/tasks/{dir_num}/{task_name}")
        async def get_task_details(dir_num: int, task_name: str):
            """Get detailed information about a specific task."""
            directory = self._find_directory(dir_num)
            if directory is None:
                # Distinguish a bad directory number from a bad task name.
                raise HTTPException(status_code=404, detail="Directory not found")
            for t in directory.tasks:
                if t.script_name == task_name:
                    return t
            raise HTTPException(status_code=404, detail="Task not found")

        @self.app.get("/metrics")
        async def get_metrics():
            """Get current process metrics for all running tasks."""
            metrics = []
            for d in self.runner.dag.directories:
                for t in d.tasks:
                    if t.state != "RUNNING":
                        continue
                    pm = t.process_metrics
                    metrics.append({
                        "directory": d.directory_num,
                        "task": t.script_name,
                        "cpu_percent": pm.cpu_percent,
                        # memory_rss is divided by 1024**2 to report MiB.
                        "memory_rss_mb": pm.memory_rss / (1024 * 1024),
                        "memory_percent": pm.memory_percent,
                        "num_threads": pm.num_threads,
                        "num_children": pm.num_children,
                    })
            return metrics

        @self.app.get("/dag")
        async def get_full_dag():
            """Get the complete DAG structure."""
            # NOTE(review): when served from a background thread while the
            # runner executes (as TaskOrchestrator does), asdict() deep-copies
            # the DAG without any lock visible here; confirm whether TaskRunner
            # guards concurrent mutation.
            return asdict(self.runner.dag)

    def start(self, host: str = "0.0.0.0", port: int = 8000):
        """Start the FastAPI server via ``uvicorn.run`` (blocking call).

        Args:
            host: Interface to bind; the default listens on all interfaces.
            port: TCP port to listen on.
        """
        uvicorn.run(self.app, host=host, port=port)
|
|
|
|
|
|
class TaskOrchestrator:
    """Main orchestrator that runs tasks and serves the API alongside them."""

    def __init__(self, tasks_dir: str, api_port: int = 8000):
        """Build the task runner and its API wrapper.

        Args:
            tasks_dir: Directory handed to ``TaskRunner``.
            api_port: Port the API server listens on in ``run()``; also the
                default for ``start_api_server()`` when no port is given.
        """
        self.runner = TaskRunner(tasks_dir)
        self.api = TaskRunnerAPI(self.runner)
        self.api_thread = None
        self.api_port = api_port

    def start_api_server(self, port: Optional[int] = None):
        """Start the API server in a separate daemon thread.

        Args:
            port: Port to listen on. ``None`` (the default) means use
                ``self.api_port``, so the configured port is honoured.
        """
        if port is None:
            port = self.api_port
        # daemon=True: the server thread does not keep the process alive once
        # the main thread finishes.
        self.api_thread = threading.Thread(
            target=self.api.start,
            args=("0.0.0.0", port),
            daemon=True,
        )
        self.api_thread.start()
        # Printed once the thread is started; the server may still be binding.
        print(f"API server started on http://0.0.0.0:{port}")

    def run(self):
        """Run the task orchestration, then keep the API available.

        Starts the API server, resets and runs the task DAG, and records a
        FAILED state (with an end time) if the run raises. Afterwards the main
        thread waits until Ctrl+C is pressed or the server thread exits.
        """
        self.start_api_server(self.api_port)

        self.runner.reset()
        try:
            self.runner.run()
        except Exception as e:
            # Top-level boundary: persist the failure in the DAG so API
            # clients can see it instead of the process just dying.
            print(f"Error during execution: {e}")
            self.runner.dag.state = "FAILED"
            self.runner.dag.end_time = datetime.now().isoformat()
            self.runner._save_dag()

        print("\nExecution completed. API server still running.")
        print(f"Access API at: http://localhost:{self.api_port}")
        print("Press Ctrl+C to stop the API server.")

        try:
            # Keep the main thread alive for API access. Waiting on the server
            # thread (rather than looping forever) lets us return if it exits,
            # e.g. because the server could not start; the short timeout keeps
            # Ctrl+C responsive.
            while self.api_thread.is_alive():
                self.api_thread.join(timeout=1)
        except KeyboardInterrupt:
            print("\nShutting down...")
        else:
            print("\nAPI server stopped.")