# File: ss-tools/backend/src/core/task_manager/persistence.py
# Viewer metadata: 542 lines, 23 KiB, Python

# [DEF:TaskPersistenceModule:Module]
# @TIER: CRITICAL
# @SEMANTICS: persistence, sqlite, sqlalchemy, task, storage
# @PURPOSE: Handles the persistence of tasks using SQLAlchemy and the tasks.db database.
# @LAYER: Core
# @PRE: Tasks database must be initialized with TaskRecord and TaskLogRecord schemas.
# @POST: Provides reliable storage and retrieval for task metadata and logs.
# @SIDE_EFFECT: Performs database I/O on tasks.db.
# @DATA_CONTRACT: Input[Task, LogEntry] -> Model[TaskRecord, TaskLogRecord]
# @RELATION: [USED_BY] ->[backend.src.core.task_manager.manager.TaskManager]
# @RELATION: [DEPENDS_ON] ->[backend.src.core.database.TasksSessionLocal]
# @INVARIANT: Database schema must match the TaskRecord model structure.
# [SECTION: IMPORTS]
from datetime import datetime
from typing import List, Optional
import json
import re
from sqlalchemy.orm import Session
from ...models.task import TaskRecord, TaskLogRecord
from ...models.mapping import Environment
from ..database import TasksSessionLocal
from .models import Task, TaskStatus, LogEntry, TaskLog, LogFilter, LogStats
from ..logger import logger, belief_scope
# [/SECTION]
# [DEF:TaskPersistenceService:Class]
# @TIER: CRITICAL
# @SEMANTICS: persistence, service, database, sqlalchemy
# @PURPOSE: Provides methods to save and load tasks from the tasks.db database using SQLAlchemy.
# @RELATION: [DEPENDS_ON] ->[backend.src.core.database.TasksSessionLocal]
# @RELATION: [DEPENDS_ON] ->[backend.src.models.task.TaskRecord]
# @RELATION: [DEPENDS_ON] ->[backend.src.models.mapping.Environment]
# @RELATION: [USED_BY] ->[backend.src.core.task_manager.manager.TaskManager]
# @INVARIANT: Persistence must handle potentially missing task fields natively.
#
# @TEST_CONTRACT: TaskPersistenceService ->
# {
# required_fields: {},
# invariants: [
# "persist_task creates or updates a record",
# "load_tasks retrieves valid Task instances",
# "delete_tasks correctly removes records from the database"
# ]
# }
# @TEST_FIXTURE: valid_task_persistence -> {"task_id": "123", "status": "PENDING"}
# @TEST_EDGE: persist_invalid_task_type -> raises Exception
# @TEST_EDGE: load_corrupt_json_params -> handled gracefully
# @TEST_INVARIANT: accurate_round_trip -> verifies: [valid_task_persistence, load_corrupt_json_params]
class TaskPersistenceService:
    """Store and retrieve ``Task`` objects as ``TaskRecord`` rows.

    Every public method opens its own session through ``TasksSessionLocal``
    and closes it before returning. The write methods (``persist_task``,
    ``delete_tasks``) roll back and log on failure instead of raising.
    """

    # [DEF:_json_load_if_needed:Function]
    # @TIER: TRIVIAL
    # @PURPOSE: Safely load JSON strings from DB if necessary
    # @PRE: value is an arbitrary database value
    # @POST: Returns parsed JSON object, list, string, or primitive
    @staticmethod
    def _json_load_if_needed(value):
        """Decode ``value`` when it is a JSON-encoded string.

        Args:
            value: Raw column value; may already be decoded.

        Returns:
            ``None`` for ``None``, blank strings and the literal ``"null"``
            (case-insensitive); dicts and lists unchanged; the decoded JSON
            for parseable strings; the original value in every other case,
            including strings that are not valid JSON.
        """
        with belief_scope("TaskPersistenceService._json_load_if_needed"):
            if value is None or isinstance(value, (dict, list)):
                return value
            if not isinstance(value, str):
                return value
            stripped = value.strip()
            if stripped == "" or stripped.lower() == "null":
                return None
            try:
                return json.loads(stripped)
            except json.JSONDecodeError:
                # Not JSON: hand back the raw string rather than failing.
                return value
    # [/DEF:_json_load_if_needed:Function]

    # [DEF:_parse_datetime:Function]
    # @TIER: TRIVIAL
    # @PURPOSE: Safely parse a datetime string from the database
    # @PRE: value is an ISO string or datetime object
    # @POST: Returns datetime object or None
    @staticmethod
    def _parse_datetime(value):
        """Convert ``value`` to a ``datetime``.

        Args:
            value: ``None``, a ``datetime``, or an ISO-8601 string.

        Returns:
            ``None`` or the ``datetime`` unchanged; the parsed ``datetime``
            for a valid ISO string; ``None`` for unparseable strings and for
            any other type.
        """
        with belief_scope("TaskPersistenceService._parse_datetime"):
            if value is None or isinstance(value, datetime):
                return value
            if isinstance(value, str):
                try:
                    return datetime.fromisoformat(value)
                except ValueError:
                    return None
            return None
    # [/DEF:_parse_datetime:Function]

    # [DEF:_resolve_environment_id:Function]
    # @TIER: STANDARD
    # @PURPOSE: Resolve environment id into existing environments.id value to satisfy FK constraints.
    # @PRE: Session is active
    # @POST: Returns existing environments.id or None when unresolved.
    # @DATA_CONTRACT: Input[env_id: Optional[str]] -> Output[Optional[str]]
    @staticmethod
    def _resolve_environment_id(session: Session, env_id: Optional[str]) -> Optional[str]:
        """Map a loosely specified environment reference to an ``Environment.id``.

        Matching is attempted in order: exact id, exact name, then a
        slug comparison in which both sides are lower-cased and every run of
        characters outside ``[a-z0-9]`` is collapsed to ``-``.

        Args:
            session: Session used to query ``Environment``.
            env_id: Id, name or slug of the environment; may be ``None``.

        Returns:
            The matching ``Environment.id`` as a string, or ``None`` when the
            input is blank or nothing matches.
        """
        with belief_scope("_resolve_environment_id"):
            raw_value = str(env_id or "").strip()
            if not raw_value:
                return None

            # 1) Direct match by primary key.
            by_id = session.query(Environment).filter(Environment.id == raw_value).first()
            if by_id:
                return str(by_id.id)

            # 2) Exact match by name.
            by_name = session.query(Environment).filter(Environment.name == raw_value).first()
            if by_name:
                return str(by_name.id)

            # 3) Slug-like match (e.g. "ss-dev" -> "SS DEV").
            def normalize_token(value: str) -> str:
                lowered = str(value or "").strip().lower()
                return re.sub(r"[^a-z0-9]+", "-", lowered).strip("-")

            target_token = normalize_token(raw_value)
            if not target_token:
                return None
            for env in session.query(Environment).all():
                if target_token in (normalize_token(env.id), normalize_token(env.name)):
                    return str(env.id)
            return None
    # [/DEF:_resolve_environment_id:Function]

    # [DEF:_to_json_safe:Function]
    # @TIER: TRIVIAL
    # @PURPOSE: Recursively replace datetimes with ISO strings so a value can be stored as JSON.
    # @PRE: obj is an arbitrary Python value
    # @POST: Returns a structure in which dicts/lists/tuples are rebuilt and datetimes are ISO strings.
    @staticmethod
    def _to_json_safe(obj):
        """Return a copy of ``obj`` with every nested ``datetime`` as ISO text.

        Dicts are rebuilt key by key, lists and tuples become lists, and
        ``datetime`` values become ``isoformat()`` strings. Any other value is
        returned unchanged.
        """
        with belief_scope("TaskPersistenceService.json_serializable"):
            if isinstance(obj, dict):
                return {
                    key: TaskPersistenceService._to_json_safe(item)
                    for key, item in obj.items()
                }
            if isinstance(obj, (list, tuple)):
                return [TaskPersistenceService._to_json_safe(item) for item in obj]
            if isinstance(obj, datetime):
                return obj.isoformat()
            return obj
    # [/DEF:_to_json_safe:Function]

    # [DEF:__init__:Function]
    # @TIER: STANDARD
    # @PURPOSE: Initializes the persistence service.
    # @PRE: None.
    # @POST: Service is ready.
    def __init__(self):
        """Create the service; it holds no state of its own."""
        with belief_scope("TaskPersistenceService.__init__"):
            # We use TasksSessionLocal from database.py
            pass
    # [/DEF:__init__:Function]

    # [DEF:persist_task:Function]
    # @TIER: STANDARD
    # @PURPOSE: Persists or updates a single task in the database.
    # @PRE: isinstance(task, Task)
    # @POST: Task record created or updated in database.
    # @PARAM: task (Task) - The task object to persist.
    # @SIDE_EFFECT: Writes to task_records table in tasks.db
    # @DATA_CONTRACT: Input[Task] -> Model[TaskRecord]
    # @RELATION: [CALLS] ->[self._resolve_environment_id]
    # @RELATION: [CALLS] ->[self._to_json_safe]
    def persist_task(self, task: Task) -> None:
        """Insert or update the ``TaskRecord`` that mirrors ``task``.

        Args:
            task: Task to store. ``params`` and ``logs`` may be ``None``; they
                are then treated as empty when reading from them.

        Any exception raised while writing is caught, the transaction is
        rolled back and the failure is logged; nothing is re-raised.
        """
        with belief_scope("TaskPersistenceService.persist_task", f"task_id={task.id}"):
            session: Session = TasksSessionLocal()
            try:
                record = session.query(TaskRecord).filter(TaskRecord.id == task.id).first()
                if not record:
                    record = TaskRecord(id=task.id)
                    session.add(record)

                record.type = task.plugin_id
                record.status = task.status.value

                # A task without a params dict must still be persistable, so
                # only look up the environment keys when there is a dict.
                params_lookup = task.params if isinstance(task.params, dict) else {}
                raw_env_id = params_lookup.get("environment_id") or params_lookup.get("source_env_id")
                record.environment_id = self._resolve_environment_id(session, raw_env_id)

                record.started_at = task.started_at
                record.finished_at = task.finished_at

                # Ensure params and result are JSON serializable
                record.params = self._to_json_safe(task.params)
                record.result = self._to_json_safe(task.result)

                # Store logs as JSON. The whole entry is sanitized so that a
                # datetime in any field (not only timestamp/context) is
                # converted. The list is assigned once, fully built, instead
                # of being mutated after assignment.
                task_logs = task.logs or []
                record.logs = [self._to_json_safe(log.dict()) for log in task_logs]

                # Extract error if failed: the most recent ERROR entry wins.
                # NOTE(review): record.error is never cleared here when the
                # task is not FAILED - confirm whether that is intended.
                if task.status == TaskStatus.FAILED:
                    for log in reversed(task_logs):
                        if log.level == "ERROR":
                            record.error = log.message
                            break

                session.commit()
            except Exception as e:
                session.rollback()
                logger.error(f"Failed to persist task {task.id}: {e}")
            finally:
                session.close()
    # [/DEF:persist_task:Function]

    # [DEF:persist_tasks:Function]
    # @TIER: STANDARD
    # @PURPOSE: Persists multiple tasks.
    # @PRE: isinstance(tasks, list)
    # @POST: All tasks in list are persisted.
    # @PARAM: tasks (List[Task]) - The list of tasks to persist.
    # @RELATION: [CALLS] ->[self.persist_task]
    def persist_tasks(self, tasks: List[Task]) -> None:
        """Persist each task in ``tasks`` with a separate ``persist_task`` call."""
        with belief_scope("TaskPersistenceService.persist_tasks"):
            for task in tasks:
                self.persist_task(task)
    # [/DEF:persist_tasks:Function]

    # [DEF:_record_to_task:Function]
    # @TIER: STANDARD
    # @PURPOSE: Rebuild a Task from one TaskRecord row.
    # @PRE: record is a loaded TaskRecord
    # @POST: Returns a Task; raises if the row cannot be converted.
    # @RELATION: [CALLS] ->[self._json_load_if_needed]
    # @RELATION: [CALLS] ->[self._parse_datetime]
    def _record_to_task(self, record) -> Task:
        """Convert one ``TaskRecord`` into a ``Task``.

        Log entries that are not dicts are skipped. A log timestamp that is
        missing or unparseable is replaced with the current UTC time.
        ``params`` falls back to ``{}`` unless it decodes to a dict.

        Raises:
            Exception: Whatever ``LogEntry``, ``TaskStatus`` or ``Task``
                raise for invalid data; the caller decides how to handle it.
        """
        logs = []
        logs_payload = self._json_load_if_needed(record.logs)
        if isinstance(logs_payload, list):
            for log_data in logs_payload:
                if not isinstance(log_data, dict):
                    continue
                # Copy so the stored payload is not modified.
                log_data = dict(log_data)
                log_data['timestamp'] = (
                    self._parse_datetime(log_data.get('timestamp')) or datetime.utcnow()
                )
                logs.append(LogEntry(**log_data))

        params = self._json_load_if_needed(record.params)
        return Task(
            id=record.id,
            plugin_id=record.type,
            status=TaskStatus(record.status),
            started_at=self._parse_datetime(record.started_at),
            finished_at=self._parse_datetime(record.finished_at),
            params=params if isinstance(params, dict) else {},
            result=self._json_load_if_needed(record.result),
            logs=logs,
        )
    # [/DEF:_record_to_task:Function]

    # [DEF:load_tasks:Function]
    # @TIER: STANDARD
    # @PURPOSE: Loads tasks from the database.
    # @PRE: limit is an integer.
    # @POST: Returns list of Task objects.
    # @PARAM: limit (int) - Max tasks to load.
    # @PARAM: status (Optional[TaskStatus]) - Filter by status.
    # @RETURN: List[Task] - The loaded tasks.
    # @DATA_CONTRACT: Model[TaskRecord] -> Output[List[Task]]
    # @RELATION: [CALLS] ->[self._record_to_task]
    def load_tasks(self, limit: int = 100, status: Optional[TaskStatus] = None) -> List[Task]:
        """Load the most recently created tasks.

        Args:
            limit: Maximum number of rows to read.
            status: When given, only rows with this status are read.

        Returns:
            Tasks ordered by ``created_at`` descending. A row that cannot be
            converted is logged and left out of the result.
        """
        with belief_scope("TaskPersistenceService.load_tasks"):
            session: Session = TasksSessionLocal()
            try:
                query = session.query(TaskRecord)
                if status is not None:
                    query = query.filter(TaskRecord.status == status.value)
                records = query.order_by(TaskRecord.created_at.desc()).limit(limit).all()

                loaded_tasks = []
                for record in records:
                    try:
                        loaded_tasks.append(self._record_to_task(record))
                    except Exception as e:
                        # One bad row must not prevent loading the others.
                        logger.error(f"Failed to reconstruct task {record.id}: {e}")
                return loaded_tasks
            finally:
                session.close()
    # [/DEF:load_tasks:Function]

    # [DEF:delete_tasks:Function]
    # @TIER: STANDARD
    # @PURPOSE: Deletes specific tasks from the database.
    # @PRE: task_ids is a list of strings.
    # @POST: Specified task records deleted from database.
    # @PARAM: task_ids (List[str]) - List of task IDs to delete.
    # @SIDE_EFFECT: Deletes rows from task_records table.
    def delete_tasks(self, task_ids: List[str]) -> None:
        """Delete the ``TaskRecord`` rows whose id is in ``task_ids``.

        An empty list is a no-op. Failures are rolled back and logged, not
        re-raised.
        """
        if not task_ids:
            return
        with belief_scope("TaskPersistenceService.delete_tasks"):
            session: Session = TasksSessionLocal()
            try:
                session.query(TaskRecord).filter(
                    TaskRecord.id.in_(task_ids)
                ).delete(synchronize_session=False)
                session.commit()
            except Exception as e:
                session.rollback()
                logger.error(f"Failed to delete tasks: {e}")
            finally:
                session.close()
    # [/DEF:delete_tasks:Function]
# [/DEF:TaskPersistenceService:Class]
# [DEF:TaskLogPersistenceService:Class]
# @TIER: CRITICAL
# @SEMANTICS: persistence, service, database, log, sqlalchemy
# @PURPOSE: Provides methods to save and query task logs from the task_logs table.
# @RELATION: [DEPENDS_ON] ->[backend.src.models.task.TaskLogRecord]
# @RELATION: [DEPENDS_ON] ->[backend.src.core.database.TasksSessionLocal]
# @RELATION: [USED_BY] ->[backend.src.core.task_manager.manager.TaskManager]
# @INVARIANT: Log entries are batch-inserted for performance.
#
# @TEST_CONTRACT: TaskLogPersistenceService ->
# {
# required_fields: {},
# invariants: [
# "add_logs efficiently saves logs to the database",
# "get_logs retrieves properly filtered LogEntry objects"
# ]
# }
# @TEST_FIXTURE: valid_log_batch -> {"task_id": "123", "logs": [{"level": "INFO", "message": "msg"}]}
# @TEST_EDGE: empty_log_list -> no-op behavior
# @TEST_EDGE: add_logs_db_error -> rollback and log error
# @TEST_INVARIANT: accurate_log_aggregation -> verifies: [valid_log_batch]
class TaskLogPersistenceService:
    """
    Service for persisting and querying task logs.
    Supports batch inserts, filtering, and statistics.

    Every method opens its own session through ``TasksSessionLocal`` and
    closes it before returning. Write methods roll back and log on failure
    instead of raising.
    """

    # Escape character used for LIKE patterns built from user input.
    _LIKE_ESCAPE = "\\"

    # [DEF:__init__:Function]
    # @TIER: STANDARD
    # @PURPOSE: Initializes the TaskLogPersistenceService
    # @PRE: config is provided or defaults are used
    # @POST: Service is ready for log persistence
    def __init__(self, config=None):
        """Create the service. ``config`` is accepted but not used here."""
        pass
    # [/DEF:__init__:Function]

    # [DEF:_escape_like:Function]
    # @TIER: TRIVIAL
    # @PURPOSE: Escape LIKE wildcards so a search term is matched literally.
    # @PRE: term is a string
    # @POST: Returns term with the escape character, % and _ escaped.
    @classmethod
    def _escape_like(cls, term: str) -> str:
        """Return ``term`` with ``\\``, ``%`` and ``_`` prefixed by the escape char.

        The escape character itself is doubled first so that the escapes added
        for ``%`` and ``_`` are not escaped again.
        """
        esc = cls._LIKE_ESCAPE
        return (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )
    # [/DEF:_escape_like:Function]

    # [DEF:_dump_metadata:Function]
    # @TIER: TRIVIAL
    # @PURPOSE: Encode log metadata as a JSON string for storage.
    # @PRE: metadata is a JSON-like value or None
    # @POST: Returns a JSON string, or None when metadata is empty.
    @staticmethod
    def _dump_metadata(metadata) -> Optional[str]:
        """Serialize ``metadata`` to JSON, or return ``None`` when it is falsy.

        Values ``json`` cannot encode (for example ``datetime``) are written
        as ``str(value)``. This is deliberate: metadata is diagnostic, and one
        odd value should not make the whole log batch fail.
        """
        if not metadata:
            return None
        return json.dumps(metadata, default=str)
    # [/DEF:_dump_metadata:Function]

    # [DEF:_load_metadata:Function]
    # @TIER: TRIVIAL
    # @PURPOSE: Decode stored log metadata.
    # @PRE: metadata_json is a stored column value
    # @POST: Returns the decoded value, or None when empty or undecodable.
    @staticmethod
    def _load_metadata(metadata_json):
        """Decode ``metadata_json``; return ``None`` if it is empty or invalid."""
        if not metadata_json:
            return None
        try:
            return json.loads(metadata_json)
        except (json.JSONDecodeError, TypeError):
            return None
    # [/DEF:_load_metadata:Function]

    # [DEF:add_logs:Function]
    # @TIER: STANDARD
    # @PURPOSE: Batch insert log entries for a task.
    # @PRE: logs is a list of LogEntry objects.
    # @POST: All logs inserted into task_logs table.
    # @PARAM: task_id (str) - The task ID.
    # @PARAM: logs (List[LogEntry]) - Log entries to insert.
    # @SIDE_EFFECT: Writes to task_logs table.
    # @DATA_CONTRACT: Input[List[LogEntry]] -> Model[TaskLogRecord]
    def add_logs(self, task_id: str, logs: List[LogEntry]) -> None:
        """Insert ``logs`` for ``task_id`` in a single transaction.

        Args:
            task_id: Id of the task the entries belong to.
            logs: Entries to store; an empty list is a no-op. A missing
                ``source`` is stored as ``"system"``.

        On any failure the whole batch is rolled back and the error is
        logged; nothing is re-raised.
        """
        if not logs:
            return
        with belief_scope("TaskLogPersistenceService.add_logs", f"task_id={task_id}"):
            session: Session = TasksSessionLocal()
            try:
                session.add_all([
                    TaskLogRecord(
                        task_id=task_id,
                        timestamp=log.timestamp,
                        level=log.level,
                        source=log.source or "system",
                        message=log.message,
                        metadata_json=self._dump_metadata(log.metadata),
                    )
                    for log in logs
                ])
                session.commit()
            except Exception as e:
                session.rollback()
                logger.error(f"Failed to add logs for task {task_id}: {e}")
            finally:
                session.close()
    # [/DEF:add_logs:Function]

    # [DEF:get_logs:Function]
    # @TIER: STANDARD
    # @PURPOSE: Query logs for a task with filtering and pagination.
    # @PRE: task_id is a valid task ID.
    # @POST: Returns list of TaskLog objects matching filters.
    # @PARAM: task_id (str) - The task ID.
    # @PARAM: log_filter (LogFilter) - Filter parameters.
    # @RETURN: List[TaskLog] - Filtered log entries.
    # @DATA_CONTRACT: Model[TaskLogRecord] -> Output[List[TaskLog]]
    def get_logs(self, task_id: str, log_filter: LogFilter) -> List[TaskLog]:
        """Return one page of log entries for ``task_id``.

        Args:
            task_id: Id of the task whose logs are read.
            log_filter: Optional ``level`` (compared upper-cased), ``source``
                (exact match) and ``search`` (case-insensitive substring of
                the message, matched literally), plus ``offset`` and
                ``limit`` for pagination.

        Returns:
            Entries ordered oldest first. Rows with equal timestamps are
            ordered by id so that consecutive pages neither repeat nor skip
            rows. Metadata that cannot be decoded is returned as ``None``.
        """
        with belief_scope("TaskLogPersistenceService.get_logs", f"task_id={task_id}"):
            session: Session = TasksSessionLocal()
            try:
                query = session.query(TaskLogRecord).filter(TaskLogRecord.task_id == task_id)

                # Apply filters
                if log_filter.level:
                    query = query.filter(TaskLogRecord.level == log_filter.level.upper())
                if log_filter.source:
                    query = query.filter(TaskLogRecord.source == log_filter.source)
                if log_filter.search:
                    # Escape % and _ so user input is not read as wildcards.
                    search_pattern = f"%{self._escape_like(str(log_filter.search))}%"
                    query = query.filter(
                        TaskLogRecord.message.ilike(search_pattern, escape=self._LIKE_ESCAPE)
                    )

                # Order by timestamp ascending (oldest first); id breaks ties
                # so offset/limit pagination is stable.
                query = query.order_by(TaskLogRecord.timestamp.asc(), TaskLogRecord.id.asc())

                # Apply pagination
                records = query.offset(log_filter.offset).limit(log_filter.limit).all()

                return [
                    TaskLog(
                        id=record.id,
                        task_id=record.task_id,
                        timestamp=record.timestamp,
                        level=record.level,
                        source=record.source,
                        message=record.message,
                        metadata=self._load_metadata(record.metadata_json),
                    )
                    for record in records
                ]
            finally:
                session.close()
    # [/DEF:get_logs:Function]

    # [DEF:get_log_stats:Function]
    # @TIER: STANDARD
    # @PURPOSE: Get statistics about logs for a task.
    # @PRE: task_id is a valid task ID.
    # @POST: Returns LogStats with counts by level and source.
    # @PARAM: task_id (str) - The task ID.
    # @RETURN: LogStats - Statistics about task logs.
    # @DATA_CONTRACT: Model[TaskLogRecord] -> Output[LogStats]
    def get_log_stats(self, task_id: str) -> LogStats:
        """Count the log entries of ``task_id`` in total, per level and per source.

        The three counts come from three separate queries on one session.
        """
        with belief_scope("TaskLogPersistenceService.get_log_stats", f"task_id={task_id}"):
            session: Session = TasksSessionLocal()
            try:
                from sqlalchemy import func

                # Get total count
                total_count = session.query(TaskLogRecord).filter(
                    TaskLogRecord.task_id == task_id
                ).count()

                # Get counts by level
                level_counts = session.query(
                    TaskLogRecord.level,
                    func.count(TaskLogRecord.id)
                ).filter(
                    TaskLogRecord.task_id == task_id
                ).group_by(TaskLogRecord.level).all()
                by_level = {level: count for level, count in level_counts}

                # Get counts by source
                source_counts = session.query(
                    TaskLogRecord.source,
                    func.count(TaskLogRecord.id)
                ).filter(
                    TaskLogRecord.task_id == task_id
                ).group_by(TaskLogRecord.source).all()
                by_source = {source: count for source, count in source_counts}

                return LogStats(
                    total_count=total_count,
                    by_level=by_level,
                    by_source=by_source
                )
            finally:
                session.close()
    # [/DEF:get_log_stats:Function]

    # [DEF:get_sources:Function]
    # @TIER: STANDARD
    # @PURPOSE: Get unique sources for a task's logs.
    # @PRE: task_id is a valid task ID.
    # @POST: Returns list of unique source strings.
    # @PARAM: task_id (str) - The task ID.
    # @RETURN: List[str] - Unique source names.
    # @DATA_CONTRACT: Model[TaskLogRecord] -> Output[List[str]]
    def get_sources(self, task_id: str) -> List[str]:
        """Return the distinct ``source`` values of the logs of ``task_id``.

        No ordering is requested from the database.
        """
        with belief_scope("TaskLogPersistenceService.get_sources", f"task_id={task_id}"):
            session: Session = TasksSessionLocal()
            try:
                from sqlalchemy import distinct

                sources = session.query(distinct(TaskLogRecord.source)).filter(
                    TaskLogRecord.task_id == task_id
                ).all()
                # Each row is a one-element tuple holding the source value.
                return [s[0] for s in sources]
            finally:
                session.close()
    # [/DEF:get_sources:Function]

    # [DEF:delete_logs_for_task:Function]
    # @TIER: STANDARD
    # @PURPOSE: Delete all logs for a specific task.
    # @PRE: task_id is a valid task ID.
    # @POST: All logs for the task are deleted.
    # @PARAM: task_id (str) - The task ID.
    # @SIDE_EFFECT: Deletes from task_logs table.
    def delete_logs_for_task(self, task_id: str) -> None:
        """Delete every log row of ``task_id``.

        Failures are rolled back and logged, not re-raised.
        """
        with belief_scope("TaskLogPersistenceService.delete_logs_for_task", f"task_id={task_id}"):
            session: Session = TasksSessionLocal()
            try:
                session.query(TaskLogRecord).filter(
                    TaskLogRecord.task_id == task_id
                ).delete(synchronize_session=False)
                session.commit()
            except Exception as e:
                session.rollback()
                logger.error(f"Failed to delete logs for task {task_id}: {e}")
            finally:
                session.close()
    # [/DEF:delete_logs_for_task:Function]

    # [DEF:delete_logs_for_tasks:Function]
    # @TIER: STANDARD
    # @PURPOSE: Delete all logs for multiple tasks.
    # @PRE: task_ids is a list of task IDs.
    # @POST: All logs for the tasks are deleted.
    # @PARAM: task_ids (List[str]) - List of task IDs.
    # @SIDE_EFFECT: Deletes rows from task_logs table.
    def delete_logs_for_tasks(self, task_ids: List[str]) -> None:
        """Delete every log row whose task id is in ``task_ids``.

        An empty list is a no-op. Failures are rolled back and logged, not
        re-raised.
        """
        if not task_ids:
            return
        with belief_scope("TaskLogPersistenceService.delete_logs_for_tasks"):
            session: Session = TasksSessionLocal()
            try:
                session.query(TaskLogRecord).filter(
                    TaskLogRecord.task_id.in_(task_ids)
                ).delete(synchronize_session=False)
                session.commit()
            except Exception as e:
                session.rollback()
                logger.error(f"Failed to delete logs for tasks: {e}")
            finally:
                session.close()
    # [/DEF:delete_logs_for_tasks:Function]
# [/DEF:TaskLogPersistenceService:Class]
# [/DEF:TaskPersistenceModule:Module]