...
This commit is contained in:
176
herolib/infra/tmuxrunner/task_runner_api.py
Normal file
176
herolib/infra/tmuxrunner/task_runner_api.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import sys
|
||||
import toml
|
||||
import libtmux
|
||||
from libtmux.pane import Pane
|
||||
from libtmux.window import Window
|
||||
from libtmux.session import Session
|
||||
import psutil
|
||||
from typing import Dict, List, Optional, Any, Set, Tuple
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import threading
|
||||
|
||||
|
||||
|
||||
class TaskRunnerAPI:
|
||||
"""FastAPI interface for the task runner."""
|
||||
|
||||
def __init__(self, runner: EnhancedTaskRunner):
|
||||
self.runner = runner
|
||||
self.app = FastAPI(title="Task Runner API", version="1.0.0")
|
||||
|
||||
# Add CORS middleware
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
self._setup_routes()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""Setup API routes."""
|
||||
|
||||
@self.app.get("/")
|
||||
async def root():
|
||||
"""Get API information."""
|
||||
return {
|
||||
"name": "Task Runner API",
|
||||
"version": "1.0.0",
|
||||
"run_id": self.runner.run_id,
|
||||
"run_name": self.runner.run_name
|
||||
}
|
||||
|
||||
@self.app.get("/status")
|
||||
async def get_status():
|
||||
"""Get current run status."""
|
||||
return {
|
||||
"run_id": self.runner.run_id,
|
||||
"run_name": self.runner.run_name,
|
||||
"state": self.runner.dag.state,
|
||||
"start_time": self.runner.dag.start_time,
|
||||
"end_time": self.runner.dag.end_time,
|
||||
"duration_seconds": self.runner.dag.duration_seconds,
|
||||
"total_directories": self.runner.dag.total_directories,
|
||||
"completed_directories": self.runner.dag.completed_directories,
|
||||
"failed_directories": self.runner.dag.failed_directories
|
||||
}
|
||||
|
||||
@self.app.get("/directories")
|
||||
async def get_directories():
|
||||
"""Get all directory statuses."""
|
||||
return [
|
||||
{
|
||||
"directory_num": d.directory_num,
|
||||
"directory_path": d.directory_path,
|
||||
"state": d.state,
|
||||
"timeout": d.timeout,
|
||||
"start_time": d.start_time,
|
||||
"end_time": d.end_time,
|
||||
"duration_seconds": d.duration_seconds,
|
||||
"task_count": len(d.tasks),
|
||||
"tasks_done": sum(1 for t in d.tasks if t.state == "DONE"),
|
||||
"tasks_error": sum(1 for t in d.tasks if t.state in ["ERROR", "CRASHED", "TIMED_OUT"])
|
||||
}
|
||||
for d in self.runner.dag.directories
|
||||
]
|
||||
|
||||
@self.app.get("/directories/{dir_num}/tasks")
|
||||
async def get_directory_tasks(dir_num: int):
|
||||
"""Get tasks for a specific directory."""
|
||||
for d in self.runner.dag.directories:
|
||||
if d.directory_num == dir_num:
|
||||
return d.tasks
|
||||
raise HTTPException(status_code=404, detail="Directory not found")
|
||||
|
||||
@self.app.get("/tasks/{dir_num}/{task_name}")
|
||||
async def get_task_details(dir_num: int, task_name: str):
|
||||
"""Get detailed information about a specific task."""
|
||||
for d in self.runner.dag.directories:
|
||||
if d.directory_num == dir_num:
|
||||
for t in d.tasks:
|
||||
if t.script_name == task_name:
|
||||
return t
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
@self.app.get("/metrics")
|
||||
async def get_metrics():
|
||||
"""Get current process metrics for all running tasks."""
|
||||
metrics = []
|
||||
for d in self.runner.dag.directories:
|
||||
for t in d.tasks:
|
||||
if t.state == "RUNNING":
|
||||
metrics.append({
|
||||
"directory": d.directory_num,
|
||||
"task": t.script_name,
|
||||
"cpu_percent": t.process_metrics.cpu_percent,
|
||||
"memory_rss_mb": t.process_metrics.memory_rss / (1024 * 1024),
|
||||
"memory_percent": t.process_metrics.memory_percent,
|
||||
"num_threads": t.process_metrics.num_threads,
|
||||
"num_children": t.process_metrics.num_children
|
||||
})
|
||||
return metrics
|
||||
|
||||
@self.app.get("/dag")
|
||||
async def get_full_dag():
|
||||
"""Get the complete DAG structure."""
|
||||
return asdict(self.runner.dag)
|
||||
|
||||
def start(self, host: str = "0.0.0.0", port: int = 8000):
|
||||
"""Start the FastAPI server."""
|
||||
uvicorn.run(self.app, host=host, port=port)
|
||||
|
||||
class TaskOrchestrator:
|
||||
"""Main orchestrator that runs tasks and API server."""
|
||||
|
||||
def __init__(self, tasks_dir: str, api_port: int = 8000):
|
||||
self.runner = EnhancedTaskRunner(tasks_dir)
|
||||
self.api = TaskRunnerAPI(self.runner)
|
||||
self.api_thread = None
|
||||
|
||||
def start_api_server(self, port: int = 8000):
|
||||
"""Start API server in a separate thread."""
|
||||
self.api_thread = threading.Thread(
|
||||
target=self.api.start,
|
||||
args=("0.0.0.0", port),
|
||||
daemon=True
|
||||
)
|
||||
self.api_thread.start()
|
||||
print(f"API server started on http://0.0.0.0:{port}")
|
||||
|
||||
def run(self):
|
||||
"""Run the task orchestration."""
|
||||
# Start API server
|
||||
self.start_api_server()
|
||||
|
||||
# Reset and run tasks
|
||||
self.runner.reset()
|
||||
try:
|
||||
self.runner.run()
|
||||
except Exception as e:
|
||||
print(f"Error during execution: {e}")
|
||||
self.runner.dag.state = "FAILED"
|
||||
self.runner.dag.end_time = datetime.now().isoformat()
|
||||
self.runner._save_dag()
|
||||
|
||||
print("\nExecution completed. API server still running.")
|
||||
print("Press Ctrl+C to stop the API server.")
|
||||
|
||||
try:
|
||||
# Keep the main thread alive for API access
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutting down...")
|
Reference in New Issue
Block a user