From ae8bb9a9a767ea1da5bf7e5f59aa72ca6b0759da Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Wed, 17 Sep 2025 11:40:11 -0400 Subject: [PATCH 1/2] enqueue refactor --- .../frontend/web/src/app/logging/logger.ts | 1 + .../workflow/PublishWorkflowPanelContent.tsx | 8 +- .../features/queue/hooks/useEnqueueCanvas.ts | 200 +++++++----------- .../queue/hooks/useEnqueueGenerate.ts | 188 +++++++--------- .../queue/hooks/useEnqueueUpscaling.ts | 80 +++---- .../features/queue/hooks/useEnqueueVideo.ts | 182 ++++++++-------- .../queue/hooks/useEnqueueWorkflows.ts | 95 ++++----- .../queue/hooks/utils/executeEnqueue.ts | 70 ++++++ .../queue/hooks/utils/graphBuilders.ts | 34 +++ 9 files changed, 434 insertions(+), 424 deletions(-) create mode 100644 invokeai/frontend/web/src/features/queue/hooks/utils/executeEnqueue.ts create mode 100644 invokeai/frontend/web/src/features/queue/hooks/utils/graphBuilders.ts diff --git a/invokeai/frontend/web/src/app/logging/logger.ts b/invokeai/frontend/web/src/app/logging/logger.ts index 1f753f97bb7..67f5f584086 100644 --- a/invokeai/frontend/web/src/app/logging/logger.ts +++ b/invokeai/frontend/web/src/app/logging/logger.ts @@ -27,6 +27,7 @@ export const zLogNamespace = z.enum([ 'queue', 'workflows', 'video', + 'enqueue', ]); export type LogNamespace = z.infer; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent.tsx index 1f90716819b..3c5e975e368 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent.tsx @@ -253,11 +253,11 @@ const PublishWorkflowButton = memo(() => { ), duration: null, }); - assert(result.value.enqueueResult.batch.batch_id); - assert(result.value.batchConfig.validation_run_data); + assert(result.value?.enqueueResult.batch.batch_id); + assert(result.value?.batchConfig.validation_run_data); $validationRunData.set({ - batchId: result.value.enqueueResult.batch.batch_id, - workflowId: result.value.batchConfig.validation_run_data.workflow_id, + batchId: result.value?.enqueueResult.batch.batch_id, + workflowId: result.value?.batchConfig.validation_run_data.workflow_id, }); log.debug(parseify(result.value), 'Enqueued batch'); } diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts index 9d5a589f056..009887b12be 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts @@ -1,154 +1,114 @@ import type { AlertStatus } from '@invoke-ai/ui-library'; import { createAction } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; -import type { AppStore } from 'app/store/store'; import { useAppStore } from 'app/store/storeHooks'; import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError'; import { withResult, withResultAsync } from 'common/util/result'; import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; -import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph'; -import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph'; -import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph'; -import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph'; -import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph'; -import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph'; -import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph'; -import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; -import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph'; -import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; import { selectCanvasDestination } from 'features/nodes/util/graph/graphBuilderUtils'; import type { GraphBuilderArg } from 'features/nodes/util/graph/types'; import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { serializeError } from 'serialize-error'; -import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; -import { assert, AssertionError } from 'tsafe'; +import { AssertionError } from 'tsafe'; + +import type { EnqueueBatchArg } from './utils/executeEnqueue'; +import { executeEnqueue } from './utils/executeEnqueue'; +import { buildGraphForBase } from './utils/graphBuilders'; const log = logger('generation'); export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas'); -const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prepend: boolean) => { - const { dispatch, getState } = store; - - dispatch(enqueueRequestedCanvas()); - - const state = getState(); - - const destination = selectCanvasDestination(state); - - const model = state.params.model; - if (!model) { - log.error('No model found in state'); - return; - } - - const base = model.base; - - const buildGraphResult = await withResultAsync(async () => { - const generationMode = await canvasManager.compositor.getGenerationMode(); - const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager }; - - switch (base) { - case 'sdxl': - return await buildSDXLGraph(graphBuilderArg); - case 'sd-1': - case `sd-2`: - return await buildSD1Graph(graphBuilderArg); - case `sd-3`: - return await buildSD3Graph(graphBuilderArg); - case `flux`: - return await buildFLUXGraph(graphBuilderArg); - case 'cogview4': - return await buildCogView4Graph(graphBuilderArg); - case 'imagen3': - return buildImagen3Graph(graphBuilderArg); - case 'imagen4': - return buildImagen4Graph(graphBuilderArg); - case 'chatgpt-4o': - return await buildChatGPT4oGraph(graphBuilderArg); - case 'flux-kontext': - return buildFluxKontextGraph(graphBuilderArg); - case 'gemini-2.5': - return buildGemini2_5Graph(graphBuilderArg); - default: - assert(false, `No graph builders for base ${base}`); - } - }); - - if (buildGraphResult.isErr()) { - let title = 'Failed to build graph'; - let status: AlertStatus = 'error'; - let description: string | null = null; - if (buildGraphResult.error instanceof AssertionError) { - description = extractMessageFromAssertionError(buildGraphResult.error); - } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) { - title = 'Unsupported generation mode'; - description = buildGraphResult.error.message; - status = 'warning'; - } - const error = serializeError(buildGraphResult.error); - log.error({ error }, 'Failed to build graph'); - toast({ - status, - title, - description, - }); - return; - } - - const { g, seed, positivePrompt } = buildGraphResult.value; - - const prepareBatchResult = withResult(() => - prepareLinearUIBatch({ - state, - g, - base, - prepend, - seedNode: seed, - positivePromptNode: positivePrompt, - origin: 'canvas', - destination, - }) - ); - - if (prepareBatchResult.isErr()) { - log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch'); - return; - } - - const batchConfig = prepareBatchResult.value; - - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(batchConfig, { - ...enqueueMutationFixedCacheKeyOptions, - track: false, - }) - ); - - const enqueueResult = await req.unwrap(); - - // Push to prompt history on successful enqueue - dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); - - return { batchConfig, enqueueResult }; +type CanvasBuildResult = { + batchConfig: EnqueueBatchArg; }; export const useEnqueueCanvas = () => { const store = useAppStore(); const canvasManager = useCanvasManagerSafe(); + const enqueue = useCallback( (prepend: boolean) => { if (!canvasManager) { log.error('Canvas manager is not available'); - return; + return null; } - return enqueueCanvas(store, canvasManager, prepend); + + return executeEnqueue({ + store, + options: { prepend }, + requestedAction: enqueueRequestedCanvas, + log, + build: async ({ store: innerStore, options }) => { + const state = innerStore.getState(); + + const destination = selectCanvasDestination(state); + const model = state.params.model; + if (!model) { + log.error('No model found in state'); + return null; + } + + const generationMode = await canvasManager.compositor.getGenerationMode(); + const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager }; + + const buildGraphResult = await withResultAsync( + async () => await buildGraphForBase(model.base, graphBuilderArg) + ); + + if (buildGraphResult.isErr()) { + let title = 'Failed to build graph'; + let status: AlertStatus = 'error'; + let description: string | null = null; + if (buildGraphResult.error instanceof AssertionError) { + description = extractMessageFromAssertionError(buildGraphResult.error); + } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) { + title = 'Unsupported generation mode'; + description = buildGraphResult.error.message; + status = 'warning'; + } + const error = serializeError(buildGraphResult.error); + log.error({ error }, 'Failed to build graph'); + toast({ status, title, description }); + return null; + } + + const { g, seed, positivePrompt } = buildGraphResult.value; + + const prepareBatchResult = withResult(() => + prepareLinearUIBatch({ + state, + g, + base: model.base, + prepend: options.prepend, + seedNode: seed, + positivePromptNode: positivePrompt, + origin: 'canvas', + destination, + }) + ); + + if (prepareBatchResult.isErr()) { + log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch'); + return null; + } + + return { + batchConfig: prepareBatchResult.value, + } satisfies CanvasBuildResult; + }, + prepareBatch: ({ buildResult }) => buildResult.batchConfig, + onSuccess: ({ store: innerStore }) => { + const state = innerStore.getState(); + innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); + }, + }); }, [canvasManager, store] ); + return enqueue; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts index 1529a87cff9..5f0a5cee7db 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts @@ -1,143 +1,103 @@ import type { AlertStatus } from '@invoke-ai/ui-library'; import { createAction } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; -import type { AppStore } from 'app/store/store'; import { useAppStore } from 'app/store/storeHooks'; import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError'; import { withResult, withResultAsync } from 'common/util/result'; import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; -import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph'; -import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph'; -import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph'; -import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph'; -import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph'; -import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph'; -import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph'; -import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; -import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph'; -import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; import type { GraphBuilderArg } from 'features/nodes/util/graph/types'; import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { serializeError } from 'serialize-error'; -import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; -import { assert, AssertionError } from 'tsafe'; +import { AssertionError } from 'tsafe'; + +import type { EnqueueBatchArg } from './utils/executeEnqueue'; +import { executeEnqueue } from './utils/executeEnqueue'; +import { buildGraphForBase } from './utils/graphBuilders'; const log = logger('generation'); export const enqueueRequestedGenerate = createAction('app/enqueueRequestedGenerate'); -const enqueueGenerate = async (store: AppStore, prepend: boolean) => { - const { dispatch, getState } = store; - - dispatch(enqueueRequestedGenerate()); - - const state = getState(); - - const model = state.params.model; - if (!model) { - log.error('No model found in state'); - return; - } - const base = model.base; - - const buildGraphResult = await withResultAsync(async () => { - const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null }; - - switch (base) { - case 'sdxl': - return await buildSDXLGraph(graphBuilderArg); - case 'sd-1': - case `sd-2`: - return await buildSD1Graph(graphBuilderArg); - case `sd-3`: - return await buildSD3Graph(graphBuilderArg); - case `flux`: - return await buildFLUXGraph(graphBuilderArg); - case 'cogview4': - return await buildCogView4Graph(graphBuilderArg); - case 'imagen3': - return buildImagen3Graph(graphBuilderArg); - case 'imagen4': - return buildImagen4Graph(graphBuilderArg); - case 'chatgpt-4o': - return await buildChatGPT4oGraph(graphBuilderArg); - case 'flux-kontext': - return buildFluxKontextGraph(graphBuilderArg); - case 'gemini-2.5': - return buildGemini2_5Graph(graphBuilderArg); - default: - assert(false, `No graph builders for base ${base}`); - } - }); - - if (buildGraphResult.isErr()) { - let title = 'Failed to build graph'; - let status: AlertStatus = 'error'; - let description: string | null = null; - if (buildGraphResult.error instanceof AssertionError) { - description = extractMessageFromAssertionError(buildGraphResult.error); - } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) { - title = 'Unsupported generation mode'; - description = buildGraphResult.error.message; - status = 'warning'; - } - const error = serializeError(buildGraphResult.error); - log.error({ error }, 'Failed to build graph'); - toast({ - status, - title, - description, - }); - return; - } - - const { g, seed, positivePrompt } = buildGraphResult.value; - - const prepareBatchResult = withResult(() => - prepareLinearUIBatch({ - state, - g, - base, - prepend, - seedNode: seed, - positivePromptNode: positivePrompt, - origin: 'generate', - destination: 'generate', - }) - ); - - if (prepareBatchResult.isErr()) { - log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch'); - return; - } - - const batchConfig = prepareBatchResult.value; - - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(batchConfig, { - ...enqueueMutationFixedCacheKeyOptions, - track: false, - }) - ); - - const enqueueResult = await req.unwrap(); - - // Push to prompt history on successful enqueue - dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); - - return { batchConfig, enqueueResult }; +type GenerateBuildResult = { + batchConfig: EnqueueBatchArg; }; export const useEnqueueGenerate = () => { const store = useAppStore(); + const enqueue = useCallback( (prepend: boolean) => { - return enqueueGenerate(store, prepend); + return executeEnqueue({ + store, + options: { prepend }, + requestedAction: enqueueRequestedGenerate, + log, + build: async ({ store: innerStore, options }) => { + const state = innerStore.getState(); + const model = state.params.model; + if (!model) { + log.error('No model found in state'); + return null; + } + + const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null }; + const buildGraphResult = await withResultAsync( + async () => await buildGraphForBase(model.base, graphBuilderArg) + ); + + if (buildGraphResult.isErr()) { + let title = 'Failed to build graph'; + let status: AlertStatus = 'error'; + let description: string | null = null; + if (buildGraphResult.error instanceof AssertionError) { + description = extractMessageFromAssertionError(buildGraphResult.error); + } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) { + title = 'Unsupported generation mode'; + description = buildGraphResult.error.message; + status = 'warning'; + } + const error = serializeError(buildGraphResult.error); + log.error({ error }, 'Failed to build graph'); + toast({ status, title, description }); + return null; + } + + const { g, seed, positivePrompt } = buildGraphResult.value; + + const prepareBatchResult = withResult(() => + prepareLinearUIBatch({ + state, + g, + base: model.base, + prepend: options.prepend, + seedNode: seed, + positivePromptNode: positivePrompt, + origin: 'generate', + destination: 'generate', + }) + ); + + if (prepareBatchResult.isErr()) { + log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch'); + return null; + } + + return { + batchConfig: prepareBatchResult.value, + } satisfies GenerateBuildResult; + }, + prepareBatch: ({ buildResult }) => buildResult.batchConfig, + onSuccess: ({ store: innerStore }) => { + const state = innerStore.getState(); + innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); + }, + }); }, [store] ); + return enqueue; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts index 01f278d98db..19983e9bcff 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts @@ -1,62 +1,64 @@ import { createAction } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; -import type { AppStore } from 'app/store/store'; import { useAppStore } from 'app/store/storeHooks'; import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph'; import { useCallback } from 'react'; -import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; + +import type { EnqueueBatchArg } from './utils/executeEnqueue'; +import { executeEnqueue } from './utils/executeEnqueue'; export const enqueueRequestedUpscaling = createAction('app/enqueueRequestedUpscaling'); const log = logger('generation'); -const enqueueUpscaling = async (store: AppStore, prepend: boolean) => { - const { dispatch, getState } = store; - - dispatch(enqueueRequestedUpscaling()); - - const state = getState(); - - const model = state.params.model; - if (!model) { - log.error('No model found in state'); - return; - } - const base = model.base; - - const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state); - - const batchConfig = prepareLinearUIBatch({ - state, - g, - base, - prepend, - seedNode: seed, - positivePromptNode: positivePrompt, - origin: 'upscaling', - destination: 'gallery', - }); - - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false }) - ); - const enqueueResult = await req.unwrap(); - - // Push to prompt history on successful enqueue - dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); - - return { batchConfig, enqueueResult }; +type UpscaleBuildResult = { + batchConfig: EnqueueBatchArg; }; export const useEnqueueUpscaling = () => { const store = useAppStore(); + const enqueue = useCallback( (prepend: boolean) => { - return enqueueUpscaling(store, prepend); + return executeEnqueue({ + store, + options: { prepend }, + requestedAction: enqueueRequestedUpscaling, + log, + build: async ({ store: innerStore, options }) => { + const state = innerStore.getState(); + const model = state.params.model; + if (!model) { + log.error('No model found in state'); + return null; + } + + const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state); + + const batchConfig = prepareLinearUIBatch({ + state, + g, + base: model.base, + prepend: options.prepend, + seedNode: seed, + positivePromptNode: positivePrompt, + origin: 'upscaling', + destination: 'gallery', + }); + + return { batchConfig } satisfies UpscaleBuildResult; + }, + prepareBatch: ({ buildResult }) => buildResult.batchConfig, + onSuccess: ({ store: innerStore }) => { + const state = innerStore.getState(); + innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); + }, + }); }, [store] ); + return enqueue; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueVideo.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueVideo.ts index 183026c3632..cca855fc770 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueVideo.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueVideo.ts @@ -1,7 +1,6 @@ import type { AlertStatus } from '@invoke-ai/ui-library'; import { createAction } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; -import type { AppStore } from 'app/store/store'; import { useAppStore } from 'app/store/storeHooks'; import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError'; import { withResult, withResultAsync } from 'common/util/result'; @@ -14,114 +13,107 @@ import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types' import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { serializeError } from 'serialize-error'; -import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; -import { assert, AssertionError } from 'tsafe'; +import { AssertionError } from 'tsafe'; + +import type { EnqueueBatchArg } from './utils/executeEnqueue'; +import { executeEnqueue } from './utils/executeEnqueue'; const log = logger('generation'); export const enqueueRequestedVideos = createAction('app/enqueueRequestedVideos'); -const enqueueVideo = async (store: AppStore, prepend: boolean) => { - const { dispatch, getState } = store; - - dispatch(enqueueRequestedVideos()); - - const state = getState(); - - const model = state.video.videoModel; - if (!model) { - log.error('No model found in state'); - return; - } - const base = model.base; - - const buildGraphResult = await withResultAsync(async () => { - const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null }; - - switch (base) { - case 'veo3': - return await buildVeo3VideoGraph(graphBuilderArg); - case 'runway': - return await buildRunwayVideoGraph(graphBuilderArg); - default: - assert(false, `No graph builders for base ${base}`); - } - }); - - if (buildGraphResult.isErr()) { - let title = 'Failed to build graph'; - let status: AlertStatus = 'error'; - let description: string | null = null; - if (buildGraphResult.error instanceof AssertionError) { - description = extractMessageFromAssertionError(buildGraphResult.error); - } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) { - title = 'Unsupported generation mode'; - description = buildGraphResult.error.message; - status = 'warning'; - } - const error = serializeError(buildGraphResult.error); - log.error({ error }, 'Failed to build graph'); - toast({ - status, - title, - description, - }); - return; - } - - const { g, positivePrompt, seed } = buildGraphResult.value; - - const prepareBatchResult = withResult(() => - prepareLinearUIBatch({ - state, - g, - base, - prepend, - seedNode: seed, - positivePromptNode: positivePrompt, - origin: 'videos', - destination: 'gallery', - }) - ); +type VideoBuildResult = { + batchConfig: EnqueueBatchArg; +}; - if (prepareBatchResult.isErr()) { - log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch'); - return; +const getVideoGraphBuilder = (base: string) => { + switch (base) { + case 'veo3': + return buildVeo3VideoGraph; + case 'runway': + return buildRunwayVideoGraph; + default: + return null; } - - const batchConfig = prepareBatchResult.value; - - // const batchConfig = { - // prepend, - // batch: { - // graph: g.getGraph(), - // runs: 1, - // origin, - // destination, - // }, - // }; - - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(batchConfig, { - ...enqueueMutationFixedCacheKeyOptions, - track: false, - }) - ); - - const enqueueResult = await req.unwrap(); - - // Push to prompt history on successful enqueue - dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); - - return { batchConfig, enqueueResult }; }; export const useEnqueueVideo = () => { const store = useAppStore(); + const enqueue = useCallback( (prepend: boolean) => { - return enqueueVideo(store, prepend); + return executeEnqueue({ + store, + options: { prepend }, + requestedAction: enqueueRequestedVideos, + log, + build: async ({ store: innerStore, options }) => { + const state = innerStore.getState(); + + const model = state.video.videoModel; + if (!model) { + log.error('No model found in state'); + return null; + } + + const builder = getVideoGraphBuilder(model.base); + if (!builder) { + log.error({ base: model.base }, 'No graph builders for base'); + return null; + } + + const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null }; + const buildGraphResult = await withResultAsync(async () => await builder(graphBuilderArg)); + + if (buildGraphResult.isErr()) { + let title = 'Failed to build graph'; + let status: AlertStatus = 'error'; + let description: string | null = null; + if (buildGraphResult.error instanceof AssertionError) { + description = extractMessageFromAssertionError(buildGraphResult.error); + } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) { + title = 'Unsupported generation mode'; + description = buildGraphResult.error.message; + status = 'warning'; + } + const error = serializeError(buildGraphResult.error); + log.error({ error }, 'Failed to build graph'); + toast({ status, title, description }); + return null; + } + + const { g, positivePrompt, seed } = buildGraphResult.value; + + const prepareBatchResult = withResult(() => + prepareLinearUIBatch({ + state, + g, + base: model.base, + prepend: options.prepend, + seedNode: seed, + positivePromptNode: positivePrompt, + origin: 'videos', + destination: 'gallery', + }) + ); + + if (prepareBatchResult.isErr()) { + log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch'); + return null; + } + + return { + batchConfig: prepareBatchResult.value, + } satisfies VideoBuildResult; + }, + prepareBatch: ({ buildResult }) => buildResult.batchConfig, + onSuccess: ({ store: innerStore }) => { + const state = innerStore.getState(); + innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); + }, + }); }, [store] ); + return enqueue; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueWorkflows.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueWorkflows.ts index 85272f4768a..2b528f7d769 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueWorkflows.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueWorkflows.ts @@ -1,5 +1,5 @@ import { createAction } from '@reduxjs/toolkit'; -import type { AppDispatch, AppStore, RootState } from 'app/store/store'; +import type { AppDispatch, RootState } from 'app/store/store'; import { useAppStore } from 'app/store/storeHooks'; import { groupBy } from 'es-toolkit/compat'; import { @@ -15,10 +15,11 @@ import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph'; import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue'; import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow'; import { useCallback } from 'react'; -import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; import type { Batch, EnqueueBatchArg, S } from 'services/api/types'; import { assert } from 'tsafe'; +import { executeEnqueue } from './utils/executeEnqueue'; + export const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows'); const getBatchDataForWorkflowGeneration = async (state: RootState, dispatch: AppDispatch): Promise => { @@ -119,60 +120,50 @@ const getValidationRunData = (state: RootState, templates: Templates): S['Valida }; }; -const enqueueWorkflows = async ( - store: AppStore, - templates: Templates, - prepend: boolean, - isApiValidationRun: boolean -) => { - const { dispatch, getState } = store; - - dispatch(enqueueRequestedWorkflows()); - const state = getState(); - const nodesState = selectNodesSlice(state); - const graph = buildNodesGraph(state, templates); - const workflow = buildWorkflowWithValidation(nodesState); - - if (workflow) { - // embedded workflows don't have an id - delete workflow.id; - } - - const runs = state.params.iterations; - const data = await getBatchDataForWorkflowGeneration(state, dispatch); - - const batchConfig: EnqueueBatchArg = { - batch: { - graph, - workflow, - runs, - origin: 'workflows', - destination: 'gallery', - data, - }, - prepend, - }; - - if (isApiValidationRun) { - batchConfig.validation_run_data = getValidationRunData(state, templates); - - // If the batch is an API validation run, we only want to run it once - batchConfig.batch.runs = 1; - } - - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false }) - ); - - const enqueueResult = await req.unwrap(); - return { batchConfig, enqueueResult }; -}; - export const useEnqueueWorkflows = () => { const store = useAppStore(); const enqueue = useCallback( (prepend: boolean, isApiValidationRun: boolean) => { - return enqueueWorkflows(store, $templates.get(), prepend, isApiValidationRun); + return executeEnqueue({ + store, + options: { prepend, isApiValidationRun }, + requestedAction: enqueueRequestedWorkflows, + build: async ({ store: innerStore, options }) => { + const { dispatch, getState } = innerStore; + const state = getState(); + const nodesState = selectNodesSlice(state); + const templates = $templates.get(); + const graph = buildNodesGraph(state, templates); + const workflow = buildWorkflowWithValidation(nodesState); + + if (workflow) { + // embedded workflows don't have an id + delete workflow.id; + } + + const data = await getBatchDataForWorkflowGeneration(state, dispatch); + + const batchConfig: EnqueueBatchArg = { + batch: { + graph, + workflow, + runs: state.params.iterations, + origin: 'workflows', + destination: 'gallery', + data, + }, + prepend: options.prepend, + }; + + if (options.isApiValidationRun) { + batchConfig.validation_run_data = getValidationRunData(state, templates); + batchConfig.batch.runs = 1; + } + + return { batchConfig } satisfies { batchConfig: EnqueueBatchArg }; + }, + prepareBatch: ({ buildResult }) => buildResult.batchConfig, + }); }, [store] ); diff --git a/invokeai/frontend/web/src/features/queue/hooks/utils/executeEnqueue.ts b/invokeai/frontend/web/src/features/queue/hooks/utils/executeEnqueue.ts new file mode 100644 index 00000000000..b71f4396a59 --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/hooks/utils/executeEnqueue.ts @@ -0,0 +1,70 @@ +import type { ActionCreatorWithoutPayload } from '@reduxjs/toolkit'; +import { logger } from 'app/logging/logger'; +import type { AppStore } from 'app/store/store'; +import { serializeError } from 'serialize-error'; +import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; +import type { paths } from 'services/api/schema'; + +export type EnqueueBatchArg = + paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json']; +export type EnqueueBatchResponse = + paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['responses']['201']['content']['application/json']; + +export type EnqueueOptionsBase = { prepend: boolean }; + +interface ExecuteEnqueueConfig { + store: AppStore; + options: TOptions; + requestedAction: ActionCreatorWithoutPayload; + build: (context: { store: AppStore; options: TOptions }) => Promise; + prepareBatch: (context: { store: AppStore; options: TOptions; buildResult: TBuildResult }) => EnqueueBatchArg; + onSuccess?: (context: { + store: AppStore; + options: TOptions; + buildResult: TBuildResult; + batch: EnqueueBatchArg; + response: EnqueueBatchResponse; + }) => void; + onError?: (context: { store: AppStore; options: TOptions; error: unknown }) => void; + log?: ReturnType; +} + +export const executeEnqueue = async ({ + store, + options, + requestedAction, + build, + prepareBatch, + onSuccess, + onError, + log = logger('enqueue'), +}: ExecuteEnqueueConfig) => { + const { dispatch } = store; + dispatch(requestedAction()); + + try { + const buildResult = await build({ store, options }); + if (!buildResult) { + return null; + } + + const batchConfig = prepareBatch({ store, options, buildResult }); + + const req = dispatch( + queueApi.endpoints.enqueueBatch.initiate(batchConfig, { + ...enqueueMutationFixedCacheKeyOptions, + track: false, + }) + ); + + const enqueueResult = await req.unwrap(); + + onSuccess?.({ store, options, buildResult, batch: batchConfig, response: enqueueResult }); + + return { batchConfig, enqueueResult }; + } catch (error) { + log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch'); + onError?.({ store, options, error }); + return null; + } +}; diff --git a/invokeai/frontend/web/src/features/queue/hooks/utils/graphBuilders.ts b/invokeai/frontend/web/src/features/queue/hooks/utils/graphBuilders.ts new file mode 100644 index 00000000000..f5bf9fb3df8 --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/hooks/utils/graphBuilders.ts @@ -0,0 +1,34 @@ +import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph'; +import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph'; +import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph'; +import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph'; +import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph'; +import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph'; +import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph'; +import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; +import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph'; +import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; +import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types'; +import { assert } from 'tsafe'; + +type GraphBuilderFn = (arg: GraphBuilderArg) => GraphBuilderReturn | Promise; + +const graphBuilderMap: Record = { + sdxl: buildSDXLGraph, + 'sd-1': buildSD1Graph, + 'sd-2': buildSD1Graph, + 'sd-3': buildSD3Graph, + flux: buildFLUXGraph, + 'flux-kontext': buildFluxKontextGraph, + cogview4: buildCogView4Graph, + imagen3: buildImagen3Graph, + imagen4: buildImagen4Graph, + 'chatgpt-4o': buildChatGPT4oGraph, + 'gemini-2.5': buildGemini2_5Graph, +}; + +export const buildGraphForBase = async (base: string, arg: GraphBuilderArg) => { + const builder = graphBuilderMap[base]; + assert(builder, `No graph builders for base ${base}`); + return await builder(arg); +}; From b5d747132686af815b5ed491c0869364c47e39a5 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Wed, 17 Sep 2025 13:53:08 -0400 Subject: [PATCH 2/2] add tests and readme --- .../frontend/web/src/features/queue/README.md | 51 +++++++++ .../queue/hooks/utils/executeEnqueue.test.ts | 107 ++++++++++++++++++ .../queue/hooks/utils/graphBuilders.test.ts | 81 +++++++++++++ 3 files changed, 239 insertions(+) create mode 100644 invokeai/frontend/web/src/features/queue/README.md create mode 100644 invokeai/frontend/web/src/features/queue/hooks/utils/executeEnqueue.test.ts create mode 100644 invokeai/frontend/web/src/features/queue/hooks/utils/graphBuilders.test.ts diff --git a/invokeai/frontend/web/src/features/queue/README.md b/invokeai/frontend/web/src/features/queue/README.md new file mode 100644 index 00000000000..0bc9bb0b908 --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/README.md @@ -0,0 +1,51 @@ +# Queue Enqueue Patterns + +This directory contains the hooks and utilities that translate UI actions into queue batches. The flow is intentionally +modular so adding a new enqueue type (e.g. a new generation mode) follows a predictable recipe. + +## Key building blocks + +- `hooks/useEnqueue*.ts` – Feature-specific hooks (generate, canvas, upscaling, video, workflows). Each hook wires local + state to the shared enqueue utilities. +- `hooks/utils/graphBuilders.ts` – Maps base models (sdxl, flux, etc.) to their graph builder functions and normalizes + synchronous vs. asynchronous builders. +- `hooks/utils/executeEnqueue.ts` – Orchestrates the enqueue lifecycle: + 1. dispatch the `enqueueRequested*` action + 2. build the graph/batch data + 3. call `queueApi.endpoints.enqueueBatch` + 4. run success/error callbacks + +## Adding a new enqueue type + +1. **Implement the graph builder (if needed).** + - Create the graph construction logic in `features/nodes/util/graph/generation/...` so it returns a + `GraphBuilderReturn`. + - If the builder reuses existing primitives, consider wiring it into `graphBuilders.ts` by extending the `graphBuilderMap`. + +2. **Create the enqueue hook.** + - Add `useEnqueue.ts` mirroring the existing hooks. Import `executeEnqueue` and supply feature-specific + `build`, `prepareBatch`, and `onSuccess` callbacks. + - If the feature depends on a new base model, add it to `graphBuilders.ts`. + +3. **Register the tab in `useInvoke`.** + - `useInvoke.ts` looks up handlers based on the active tab. Import your new hook and call it inside the `switch` + (or future registry) so the UI can enqueue from the feature. + +4. **Add Redux action (optional).** + - Most enqueue hooks dispatch a `enqueueRequested*` action for devtools visibility. Create one with `createAction` if + you want similar tracing. + +5. **Cover with tests.** + - Unit-test feature-specific behavior (graph selection, batch tweaks). The shared helpers already have coverage in + `hooks/utils/`. + +## Tips + +- Keep `build` lean: fetch state, compose graph/batch data, and return `null` when prerequisites are missing. The shared + helper will skip enqueueing and your `onError` will handle logging. +- Use the shared `prepareLinearUIBatch` for single-graph UI workflows. For advanced cases (multi-run batches, workflow + validation runs), supply a custom `prepareBatch` function. +- Prefer updating `graphBuilders.ts` when adding a new base model so every image-based enqueue automatically benefits. + +With this structure, the main task when introducing a new enqueue type is describing how to build its graph and how to +massage the batch payload—everything else (dispatching, API calls, history updates) is handled by the utilities. diff --git a/invokeai/frontend/web/src/features/queue/hooks/utils/executeEnqueue.test.ts b/invokeai/frontend/web/src/features/queue/hooks/utils/executeEnqueue.test.ts new file mode 100644 index 00000000000..21cff83b7d5 --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/hooks/utils/executeEnqueue.test.ts @@ -0,0 +1,107 @@ +import { createAction } from '@reduxjs/toolkit'; +import type { AppStore, RootState } from 'app/store/store'; +import type { EnqueueBatchArg, EnqueueBatchResponse } from './executeEnqueue'; +import { executeEnqueue } from './executeEnqueue'; +import { describe, expect, it, vi } from 'vitest'; + +const createTestStore = () => { + const state = {} as RootState; + const dispatch = vi.fn<(action: unknown) => unknown>((action) => { + if (typeof action === 'object' && action !== null && 'type' in action) { + return undefined; + } + const unwrap = vi.fn<() => Promise>().mockResolvedValue({ + batch_id: 'batch-1', + item_ids: ['item-1'], + } as EnqueueBatchResponse); + return { unwrap }; + }); + const getState = vi.fn(() => state); + return { dispatch, getState } as unknown as AppStore; +}; + +const createBatchArg = (prepend: boolean): EnqueueBatchArg => ({ + prepend, + batch: { + graph: {} as EnqueueBatchArg['batch']['graph'], + runs: 1, + data: [], + origin: 'test', + destination: 'test', + }, +}); + +describe('executeEnqueue', () => { + it('dispatches enqueue flow and invokes success callback', async () => { + const store = createTestStore(); + const requestedAction = createAction('test/enqueue'); + const options = { prepend: false } as const; + const batchConfig = createBatchArg(options.prepend); + const onSuccess = vi.fn(); + const build = vi.fn(async () => ({ batchConfig })); + const prepareBatch = vi.fn(() => batchConfig); + + const result = await executeEnqueue({ + store, + options, + requestedAction, + build, + prepareBatch, + onSuccess, + log: { error: vi.fn() }, + }); + + expect(store.dispatch).toHaveBeenCalledWith(requestedAction()); + expect(build).toHaveBeenCalledWith({ store, options }); + expect(prepareBatch).toHaveBeenCalledWith({ store, options, buildResult: { batchConfig } }); + expect(onSuccess).toHaveBeenCalled(); + expect(result?.batchConfig).toBe(batchConfig); + }); + + it('stops when build returns null', async () => { + const store = createTestStore(); + const requestedAction = createAction('test/enqueue'); + const options = { prepend: true } as const; + const build = vi.fn(async () => null); + const prepareBatch = vi.fn(); + + const result = await executeEnqueue({ + store, + options, + requestedAction, + build, + prepareBatch, + log: { error: vi.fn() }, + }); + + expect(result).toBeNull(); + expect(build).toHaveBeenCalled(); + expect(prepareBatch).not.toHaveBeenCalled(); + }); + + it('invokes onError when build throws', async () => { + const store = createTestStore(); + const requestedAction = createAction('test/enqueue'); + const options = { prepend: false } as const; + const error = new Error('boom'); + const build = vi.fn(async () => { + throw error; + }); + const onError = vi.fn(); + const logError = vi.fn(); + + const result = await executeEnqueue({ + store, + options, + requestedAction, + build, + prepareBatch: vi.fn(), + onError, + log: { error: logError }, + }); + + expect(result).toBeNull(); + expect(onError).toHaveBeenCalledWith({ store, options, error }); + expect(logError).toHaveBeenCalled(); + }); +}); diff --git a/invokeai/frontend/web/src/features/queue/hooks/utils/graphBuilders.test.ts b/invokeai/frontend/web/src/features/queue/hooks/utils/graphBuilders.test.ts new file mode 100644 index 00000000000..df91dde300d --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/hooks/utils/graphBuilders.test.ts @@ -0,0 +1,81 @@ +import { describe, expect, it, vi } from 'vitest'; +import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import type { GraphBuilderArg } from 'features/nodes/util/graph/types'; +import type { Invocation } from 'services/api/types'; +import type { RootState } from 'app/store/store'; + +const mocks = vi.hoisted(() => { + const mockGraph: Graph = {} as Graph; + const mockPrompt = { id: 'prompt-node' } as Invocation<'string'>; + const asyncReturnValue = { g: mockGraph, positivePrompt: mockPrompt }; + const syncReturnValue = { g: mockGraph, positivePrompt: mockPrompt }; + + return { + asyncReturnValue, + syncReturnValue, + buildSDXLGraphMock: vi.fn().mockResolvedValue(asyncReturnValue), + buildImagen3GraphMock: vi.fn().mockReturnValue(syncReturnValue), + createDefaultBuilder: () => vi.fn().mockResolvedValue(asyncReturnValue), + }; +}); + +vi.mock('features/nodes/util/graph/generation/buildSDXLGraph', () => ({ + buildSDXLGraph: mocks.buildSDXLGraphMock, +})); +vi.mock('features/nodes/util/graph/generation/buildSD1Graph', () => ({ + buildSD1Graph: mocks.createDefaultBuilder(), +})); +vi.mock('features/nodes/util/graph/generation/buildSD3Graph', () => ({ + buildSD3Graph: mocks.createDefaultBuilder(), +})); +vi.mock('features/nodes/util/graph/generation/buildFLUXGraph', () => ({ + buildFLUXGraph: mocks.createDefaultBuilder(), +})); +vi.mock('features/nodes/util/graph/generation/buildFluxKontextGraph', () => ({ + buildFluxKontextGraph: mocks.createDefaultBuilder(), +})); +vi.mock('features/nodes/util/graph/generation/buildCogView4Graph', () => ({ + buildCogView4Graph: mocks.createDefaultBuilder(), +})); +vi.mock('features/nodes/util/graph/generation/buildImagen3Graph', () => ({ + buildImagen3Graph: mocks.buildImagen3GraphMock, +})); +vi.mock('features/nodes/util/graph/generation/buildImagen4Graph', () => ({ + buildImagen4Graph: mocks.createDefaultBuilder(), +})); +vi.mock('features/nodes/util/graph/generation/buildChatGPT4oGraph', () => ({ + buildChatGPT4oGraph: mocks.createDefaultBuilder(), +})); +vi.mock('features/nodes/util/graph/generation/buildGemini2_5Graph', () => ({ + buildGemini2_5Graph: mocks.createDefaultBuilder(), +})); + +import { buildGraphForBase } from './graphBuilders'; + +describe('buildGraphForBase', () => { + const baseArg: GraphBuilderArg = { + generationMode: 'txt2img', + state: {} as RootState, + manager: null, + }; + + it('awaits asynchronous graph builders', async () => { + const result = await buildGraphForBase('sdxl', baseArg); + + expect(result).toBe(mocks.asyncReturnValue); + expect(mocks.buildSDXLGraphMock).toHaveBeenCalledWith(baseArg); + }); + + it('supports synchronous graph builders', async () => { + const result = await buildGraphForBase('imagen3', baseArg); + + expect(result).toBe(mocks.syncReturnValue); + expect(mocks.buildImagen3GraphMock).toHaveBeenCalledWith(baseArg); + }); + + it('throws for unknown base models', async () => { + await expect(buildGraphForBase('unknown-model', baseArg)).rejects.toThrow( + 'No graph builders for base unknown-model' + ); + }); +});