Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 52 additions & 30 deletions invokeai/app/api/routers/session_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
QUEUE_ORDER_BY,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
Expand All @@ -18,14 +18,15 @@
DeleteByDestinationResult,
EnqueueBatchResult,
FieldIdentifier,
ItemIdsResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemNotFoundError,
SessionQueueStatus,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection

session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])

Expand Down Expand Up @@ -69,54 +70,75 @@ async def enqueue_batch(


@session_queue_router.get(
"/{queue_id}/list",
operation_id="list_queue_items",
"/{queue_id}/list_all",
operation_id="list_all_queue_items",
responses={
200: {"model": CursorPaginatedResults[SessionQueueItem]},
200: {"model": list[SessionQueueItem]},
},
)
async def list_queue_items(
async def list_all_queue_items(
queue_id: str = Path(description="The queue id to perform this operation on"),
limit: int = Query(default=50, description="The number of items to fetch"),
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
priority: int = Query(default=0, description="The pagination cursor priority"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets cursor-paginated queue items"""

) -> list[SessionQueueItem]:
"""Gets all queue items"""
try:
return ApiDependencies.invoker.services.session_queue.list_queue_items(
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all items: {e}")
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")


@session_queue_router.get(
"/{queue_id}/list_all",
operation_id="list_all_queue_items",
"/{queue_id}/item_ids",
operation_id="get_queue_itemIds",
responses={
200: {"model": list[SessionQueueItem]},
200: {"model": ItemIdsResult},
},
)
async def list_all_queue_items(
async def get_queue_item_ids(
queue_id: str = Path(description="The queue id to perform this operation on"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> list[SessionQueueItem]:
"""Gets all queue items"""
order_by: QUEUE_ORDER_BY = Query(default="created_at", description="The sort field"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
) -> ItemIdsResult:
"""Gets all queue item ids that match the given parameters"""
try:
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(
queue_id=queue_id, order_by=order_by, order_dir=order_dir
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")


@session_queue_router.post(
"/{queue_id}/items_by_ids",
operation_id="get_queue_items_by_item_ids",
responses={200: {"model": list[SessionQueueItem]}},
)
async def get_queue_items_by_item_ids(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_ids: list[int] = Body(
embed=True, description="Object containing list of queue item ids to fetch queue items for"
),
) -> list[SessionQueueItem]:
"""Gets queue items for the specified queue item ids. Maintains order of item ids."""
try:
session_queue_service = ApiDependencies.invoker.services.session_queue

# Fetch queue items preserving the order of requested item ids
queue_items: list[SessionQueueItem] = []
for item_id in item_ids:
try:
queue_item = session_queue_service.get_queue_item(item_id)
queue_items.append(queue_item)
except Exception:
# Skip missing queue items - they may have been deleted between item id fetch and queue item fetch
continue

return queue_items
except Exception:
raise HTTPException(status_code=500, detail="Failed to get queue items")


@session_queue_router.put(
Expand Down
8 changes: 4 additions & 4 deletions invokeai/app/services/events/events_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
error_type: Optional[str] = Field(default=None, description="The error type, if any")
error_message: Optional[str] = Field(default=None, description="The error message, if any")
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
created_at: str = Field(description="The timestamp when the queue item was created")
updated_at: str = Field(description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
Expand All @@ -258,8 +258,8 @@ def build(
error_type=queue_item.error_type,
error_message=queue_item.error_message,
error_traceback=queue_item.error_traceback,
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
created_at=str(queue_item.created_at),
updated_at=str(queue_item.updated_at),
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
Expand Down
24 changes: 11 additions & 13 deletions invokeai/app/services/session_queue/session_queue_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Coroutine, Optional

from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
QUEUE_ORDER_BY,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
Expand All @@ -15,14 +15,15 @@
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
ItemIdsResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection


class SessionQueueBase(ABC):
Expand Down Expand Up @@ -136,25 +137,22 @@ def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResu
pass

@abstractmethod
def list_queue_items(
def list_all_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets a page of session queue items"""
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
pass

@abstractmethod
def list_all_queue_items(
def get_queue_item_ids(
self,
queue_id: str,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
order_by: QUEUE_ORDER_BY = "created_at",
order_dir: SQLiteDirection = SQLiteDirection.Descending,
) -> ItemIdsResult:
"""Gets all queue item ids that match the given parameters"""
pass

@abstractmethod
Expand Down
9 changes: 9 additions & 0 deletions invokeai/app/services/session_queue/session_queue_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,17 @@ def validate_graph(cls, v: Graph):

DEFAULT_QUEUE_ID = "default"

QUEUE_ORDER_BY = Literal["created_at", "completed_at"]
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]


class ItemIdsResult(BaseModel):
"""Response containing ordered item ids with metadata for optimistic updates."""

item_ids: list[int] = Field(description="Ordered list of item ids")
total_count: int = Field(description="Total number of queue items matching the query")


NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])


Expand Down
78 changes: 24 additions & 54 deletions invokeai/app/services/session_queue/session_queue_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from invokeai.app.services.session_queue.session_queue_common import (
DEFAULT_QUEUE_ID,
QUEUE_ITEM_STATUS,
QUEUE_ORDER_BY,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
Expand All @@ -22,6 +23,7 @@
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
ItemIdsResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
Expand All @@ -33,7 +35,7 @@
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase


Expand Down Expand Up @@ -587,59 +589,6 @@ def set_queue_item_session(self, item_id: int, session: GraphExecutionState) ->
)
return self.get_queue_item(item_id)

def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]

if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)

if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)

if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])

query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)

def list_all_queue_items(
self,
queue_id: str,
Expand Down Expand Up @@ -671,6 +620,27 @@ def list_all_queue_items(
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items

def get_queue_item_ids(
self,
queue_id: str,
order_by: QUEUE_ORDER_BY = "created_at",
order_dir: SQLiteDirection = SQLiteDirection.Descending,
) -> ItemIdsResult:
with self._db.transaction() as cursor_:
query = f"""--sql
SELECT item_id
FROM session_queue
WHERE queue_id = ?
ORDER BY {order_by} {order_dir.value}
"""
query_params = [queue_id]

cursor_.execute(query, query_params)
result = cast(list[sqlite3.Row], cursor_.fetchall())
item_ids = [row[0] for row in result]

return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids))

def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
with self._db.transaction() as cursor:
cursor.execute(
Expand Down
9 changes: 7 additions & 2 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@
"workflows": "Workflows",
"other": "Other",
"gallery": "Gallery",
"batchFieldValues": "Batch Field Values",
"item": "Item",
"session": "Session",
"notReady": "Unable to Queue",
Expand All @@ -324,7 +323,13 @@
"iterations_other": "Iterations",
"generations_one": "Generation",
"generations_other": "Generations",
"batchSize": "Batch Size"
"batchSize": "Batch Size",
"createdAt": "Created At",
"completedAt": "Completed At",
"sortColumn": "Sort Column",
"sortBy": "Sort by {{column}}",
"sortOrderAscending": "Ascending",
"sortOrderDescending": "Descending"
},
"invocationCache": {
"invocationCache": "Invocation Cache",
Expand Down
Loading
Loading