Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui): more batch data types #7545

Merged
merged 39 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
22316ec
fix(ui): typo in error message for image collection fields
psychedelicious Jan 9, 2025
089e6dd
feat(nodes): add string batch node
psychedelicious Jan 9, 2025
072a7c4
docs(ui): improved comments for image batch node special handling
psychedelicious Jan 9, 2025
0d0936a
tweak(ui): image field collection input component styling
psychedelicious Jan 9, 2025
e42ad66
refactor(ui): streamline image field collection input logic, support …
psychedelicious Jan 10, 2025
a8c8544
feat(ui): support string batches
psychedelicious Jan 10, 2025
aa82044
feat(nodes): add integer batch node
psychedelicious Jan 10, 2025
123ea6c
feat(ui): support integer batches
psychedelicious Jan 10, 2025
f35a2b7
feat(nodes): add float batch node
psychedelicious Jan 10, 2025
e844940
feat(ui): add template validation for string collection items
psychedelicious Jan 10, 2025
28ea247
feat(ui): add template validation for integer collection items
psychedelicious Jan 10, 2025
948ef90
refactor(ui): abstract out field validators
psychedelicious Jan 10, 2025
2a3b562
fix(ui): typo
psychedelicious Jan 10, 2025
b763da9
refactor(ui): abstract out helper to add batch data
psychedelicious Jan 10, 2025
3f54871
feat(ui): support float batches
psychedelicious Jan 10, 2025
4ddd2ca
feat(ui): validate string item lengths
psychedelicious Jan 10, 2025
fb7a842
feat(ui): validate number item multipleOf
psychedelicious Jan 10, 2025
395b4f3
chore(ui): typegen
psychedelicious Jan 10, 2025
1c2656f
tidy(ui): use zod typeguard builder util for fields
psychedelicious Jan 10, 2025
e5abb90
chore(ui): lint
psychedelicious Jan 10, 2025
cc74f00
fix(ui): float batch data creation
psychedelicious Jan 10, 2025
01e89be
fix(nodes): allow batch datum items to mix ints and floats
psychedelicious Jan 10, 2025
fb8717e
feat(ui): rough out number generators for number collection fields
psychedelicious Jan 10, 2025
389830b
perf(ui): memoize selector in workflows
psychedelicious Jan 13, 2025
645db87
fix(ui): filter out batch nodes when checking readiness on workflows tab
psychedelicious Jan 13, 2025
56ae064
fix(ui): do not set number collection field to undefined when removin…
psychedelicious Jan 13, 2025
7ef86fe
feat(ui): number collection generator supports floats
psychedelicious Jan 13, 2025
cf56f0d
feat(nodes): add default value for batch nodes
psychedelicious Jan 13, 2025
1d74860
tidy(ui): abstract out batch detection logic
psychedelicious Jan 13, 2025
d0c9389
feat(ui): more batch generator stuff
psychedelicious Jan 13, 2025
a52eb63
Revert "feat(ui): more batch generator stuff"
psychedelicious Jan 13, 2025
45d7ffd
Revert "feat(ui): number collection generator supports floats"
psychedelicious Jan 13, 2025
330c9a7
Revert "feat(ui): rough out number generators for number collection f…
psychedelicious Jan 13, 2025
a0803d5
feat(ui): add number range generators
psychedelicious Jan 13, 2025
7dee7fe
fix(ui): translation key
psychedelicious Jan 13, 2025
fb2ff00
feat(ui): string collection batch items are input not textarea
psychedelicious Jan 13, 2025
9a55b1d
tweak(ui): number collection styling
psychedelicious Jan 13, 2025
8efc010
chore(ui): typegen
psychedelicious Jan 13, 2025
8a29a11
chore(ui): lint
psychedelicious Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions invokeai/app/invocations/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,18 +544,86 @@ def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "internal"],
tags=["primitives", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""

images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
images: list[ImageField] = InputField(
default=[], min_length=1, description="The images to batch over", input=Input.Direct
)

def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")

def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""

strings: list[str] = InputField(
default=[], min_length=1, description="The strings to batch over", input=Input.Direct
)

def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")

def invoke(self, context: InvocationContext) -> StringOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""

integers: list[int] = InputField(
default=[], min_length=1, description="The integers to batch over", input=Input.Direct
)

def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")

def invoke(self, context: InvocationContext) -> IntegerOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""

floats: list[float] = InputField(
default=[], min_length=1, description="The floats to batch over", input=Input.Direct
)

def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")

def invoke(self, context: InvocationContext) -> FloatOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")
10 changes: 9 additions & 1 deletion invokeai/app/services/session_queue/session_queue_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,16 @@ def validate_types(cls, v: Optional[BatchDataCollection]):
return v
for batch_data_list in v:
for datum in batch_data_list:
if not datum.items:
continue

# Special handling for numbers - they can be mixed
# TODO(psyche): Update BatchDatum to have a `type` field to specify the type of the items, then we can have strict float and int fields
if all(isinstance(item, (int, float)) for item in datum.items):
continue

# Get the type of the first item in the list
first_item_type = type(datum.items[0]) if datum.items else None
first_item_type = type(datum.items[0])
for item in datum.items:
if type(item) is not first_item_type:
raise BatchItemsTypeError("All items in a batch must have the same type")
Expand Down
27 changes: 21 additions & 6 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,11 @@
"none": "None",
"new": "New",
"generating": "Generating",
"warnings": "Warnings"
"warnings": "Warnings",
"start": "Start",
"count": "Count",
"step": "Step",
"values": "Values"
},
"hrf": {
"hrf": "High Resolution Fix",
Expand Down Expand Up @@ -989,7 +993,11 @@
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default",
"saveToGallery": "Save To Gallery"
"saveToGallery": "Save To Gallery",
"addItem": "Add Item",
"generateValues": "Generate Values",
"floatRangeGenerator": "Float Range Generator",
"integerRangeGenerator": "Integer Range Generator"
},
"parameters": {
"aspect": "Aspect",
Expand Down Expand Up @@ -1024,11 +1032,18 @@
"addingImagesTo": "Adding images to",
"invoke": "Invoke",
"missingFieldTemplate": "Missing field template",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: missing input",
"missingInputForField": "missing input",
"missingNodeTemplate": "Missing node template",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} empty collection",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: too few items, minimum {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: too many items, maximum {{maxItems}}",
"collectionEmpty": "empty collection",
"collectionTooFewItems": "too few items, minimum {{minItems}}",
"collectionTooManyItems": "too many items, maximum {{maxItems}}",
"collectionStringTooLong": "too long, max {{maxLength}}",
"collectionStringTooShort": "too short, min {{minLength}}",
"collectionNumberGTMax": "{{value}} > {{maximum}} (inc max)",
"collectionNumberLTMin": "{{value}} < {{minimum}} (inc min)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (exc max)",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (exc min)",
"collectionNumberNotMultipleOf": "{{value}} not multiple of {{multipleOf}}",
"noModelSelected": "No model selected",
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
Expand Down
4 changes: 4 additions & 0 deletions invokeai/frontend/web/src/app/components/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicP
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { FloatRangeGeneratorModal } from 'features/nodes/components/FloatRangeGeneratorModal';
import { IntegerRangeGeneratorModal } from 'features/nodes/components/IntegerRangeGeneratorModal';
import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/WorkflowListMenu/ShareWorkflowModal';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { DeleteStylePresetDialog } from 'features/stylePresets/components/DeleteStylePresetDialog';
Expand Down Expand Up @@ -110,6 +112,8 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<ImageContextMenu />
<FullscreenDropzone />
<VideosModal />
<FloatRangeGeneratorModal />
<IntegerRangeGeneratorModal />
</ErrorBoundary>
);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@ import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import type { ImageField } from 'features/nodes/types/common';
import {
isFloatFieldCollectionInputInstance,
isImageFieldCollectionInputInstance,
isIntegerFieldCollectionInputInstance,
isStringFieldCollectionInputInstance,
} from 'features/nodes/types/field';
import type { InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
Expand Down Expand Up @@ -33,29 +40,92 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =

const data: Batch['data'] = [];

// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
const addBatchDataCollectionItem = (edges: InvocationNodeEdge[], items?: ImageField[] | string[] | number[]) => {
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromImageBatch) {
for (const edge of edges) {
if (!edge.targetHandle) {
break;
}
batchDataCollectionItem.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: images.value,
items,
});
}
if (batchDataCollectionItem.length > 0) {
data.push(batchDataCollectionItem);
}
};

// Grab image batch nodes for special handling
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');

for (const node of imageBatchNodes) {
// Satisfy TS
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}

// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
addBatchDataCollectionItem(edgesFromImageBatch, images.value);
}

// Grab string batch nodes for special handling
const stringBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'string_batch');
for (const node of stringBatchNodes) {
// Satisfy TS
const strings = node.data.inputs['strings'];
if (!isStringFieldCollectionInputInstance(strings)) {
log.warn({ nodeId: node.id }, 'String batch strings field is not a string collection');
break;
}

// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
addBatchDataCollectionItem(edgesFromStringBatch, strings.value);
}

// Grab integer batch nodes for special handling
const integerBatchNodes = nodes.nodes
.filter(isInvocationNode)
.filter((node) => node.data.type === 'integer_batch');
for (const node of integerBatchNodes) {
// Satisfy TS
const integers = node.data.inputs['integers'];
if (!isIntegerFieldCollectionInputInstance(integers)) {
log.warn({ nodeId: node.id }, 'Integer batch integers field is not an integer collection');
break;
}
if (!integers.value) {
log.warn({ nodeId: node.id }, 'Integer batch integers field is empty');
break;
}

// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
addBatchDataCollectionItem(edgesFromStringBatch, integers.value);
}

// Grab float batch nodes for special handling
const floatBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'float_batch');
for (const node of floatBatchNodes) {
// Satisfy TS
const floats = node.data.inputs['floats'];
if (!isFloatFieldCollectionInputInstance(floats)) {
log.warn({ nodeId: node.id }, 'Float batch floats field is not a float collection');
break;
}
if (!floats.value) {
log.warn({ nodeId: node.id }, 'Float batch floats field is empty');
break;
}

// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
addBatchDataCollectionItem(edgesFromStringBatch, floats.value);
}

const batchConfig: BatchConfig = {
Expand Down
28 changes: 22 additions & 6 deletions invokeai/frontend/web/src/features/dnd/dnd.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
Expand All @@ -9,7 +10,6 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
import type { BoardId } from 'features/gallery/store/types';
import {
addImagesToBoard,
addImagesToNodeImageFieldCollectionAction,
createNewCanvasEntityFromImage,
removeImagesFromBoard,
replaceCanvasEntityObjectsWithImage,
Expand All @@ -19,10 +19,14 @@ import {
setRegionalGuidanceReferenceImage,
setUpscaleInitialImage,
} from 'features/imageActions/actions';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import type { ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';

const log = logger('dnd');

type RecordUnknown = Record<string | symbol, unknown>;

type DndData<
Expand Down Expand Up @@ -268,15 +272,27 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
}

const { fieldIdentifier } = targetData.payload;
const imageDTOs: ImageDTO[] = [];

const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);

if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
return;
}

const newValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];

if (singleImageDndSource.typeGuard(sourceData)) {
imageDTOs.push(sourceData.payload.imageDTO);
newValue.push({ image_name: sourceData.payload.imageDTO.image_name });
} else {
imageDTOs.push(...sourceData.payload.imageDTOs);
newValue.push(...sourceData.payload.imageDTOs.map(({ image_name }) => ({ image_name })));
}

addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: newValue }));
},
};
//#endregion
Expand Down
Loading
Loading