Skip to content

Commit

Permalink
Prodia: unified SDXL support, with model list, priority, advanced set…
Browse files Browse the repository at this point in the history
…tings, resolution, default to R.V.5
  • Loading branch information
enricoros committed Oct 27, 2023
1 parent 6e7aa71 commit a8839b7
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 122 deletions.
31 changes: 2 additions & 29 deletions src/apps/chat/editors/image-generate.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import { apiAsync } from '~/common/util/trpc.client';
import { prodiaDefaultModelId } from '~/modules/prodia/prodia.models';
import { useProdiaStore } from '~/modules/prodia/store-prodia';
import { prodiaGenerateImage } from '~/modules/prodia/prodia.client';

import { useChatStore } from '~/common/state/store-chats';

Expand All @@ -26,32 +24,7 @@ export async function runImageGenerationUpdatingState(conversationId: string, im
const { editMessage } = useChatStore.getState();

try {

const {
prodiaApiKey: prodiaKey, prodiaModelId,
prodiaNegativePrompt: negativePrompt, prodiaSteps: steps, prodiaCfgScale: cfgScale,
prodiaAspectRatio: aspectRatio, prodiaUpscale: upscale,
prodiaSeed: seed,
} = useProdiaStore.getState();

// Run the image generation count times in parallel
const imageUrls = await Promise.all(
Array(count).fill(undefined).map(async () => {
const { imageUrl } = await apiAsync.prodia.imagine.query({
...(!!prodiaKey && { prodiaKey }),
prodiaModel: prodiaModelId || prodiaDefaultModelId,
prompt: imageText,
...(!!negativePrompt && { negativePrompt }),
...(!!steps && { steps }),
...(!!cfgScale && { cfgScale }),
...(!!aspectRatio && aspectRatio !== 'square' && { aspectRatio }),
...((upscale && { upscale })),
...(!!seed && { seed }),
});

return imageUrl;
}),
);
const imageUrls = await prodiaGenerateImage(count, imageText);

// Concatenate all the resulting URLs and update the assistant message with these URLs
const allImageUrls = imageUrls.join('\n');
Expand Down
93 changes: 74 additions & 19 deletions src/modules/prodia/ProdiaSettings.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import * as React from 'react';
import { shallow } from 'zustand/shallow';

import { Box, CircularProgress, FormControl, FormHelperText, FormLabel, Input, Option, Radio, RadioGroup, Select, Slider, Stack, Switch, Tooltip } from '@mui/joy';
import { Box, Chip, CircularProgress, FormControl, FormHelperText, FormLabel, Input, Option, Radio, RadioGroup, Select, Slider, Stack, Switch, Tooltip } from '@mui/joy';
import CropSquareIcon from '@mui/icons-material/CropSquare';
import FormatPaintIcon from '@mui/icons-material/FormatPaint';
import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined';
Expand All @@ -13,22 +13,28 @@ import { FormInputKey } from '~/common/components/forms/FormInputKey';
import { InlineError } from '~/common/components/InlineError';
import { apiQuery } from '~/common/util/trpc.client';
import { settingsGap } from '~/common/theme';
import { useToggleableBoolean } from '~/common/util/useToggleableBoolean';

import { DEFAULT_PRODIA_RESOLUTION, HARDCODED_PRODIA_RESOLUTIONS, useProdiaStore } from './store-prodia';
import { isValidProdiaApiKey, requireUserKeyProdia } from './prodia.client';
import { prodiaDefaultModelId } from './prodia.models';
import { useProdiaStore } from './store-prodia';


export function ProdiaSettings() {

// local state
const advanced = useToggleableBoolean();

// external state
const { apiKey, setApiKey, modelId, setModelId, negativePrompt, setNegativePrompt, steps, setSteps, cfgScale, setCfgScale, prodiaAspectRatio, setProdiaAspectRatio, upscale, setUpscale, seed, setSeed } = useProdiaStore(state => ({
const { apiKey, setApiKey, modelId, setModelId, modelGen, setModelGen, negativePrompt, setNegativePrompt, steps, setSteps, cfgScale, setCfgScale, prodiaAspectRatio, setProdiaAspectRatio, upscale, setUpscale, prodiaResolution, setProdiaResolution, seed, setSeed } = useProdiaStore(state => ({
apiKey: state.prodiaApiKey, setApiKey: state.setProdiaApiKey,
modelId: state.prodiaModelId, setModelId: state.setProdiaModelId,
modelGen: state.prodiaModelGen, setModelGen: state.setProdiaModelGen,
negativePrompt: state.prodiaNegativePrompt, setNegativePrompt: state.setProdiaNegativePrompt,
steps: state.prodiaSteps, setSteps: state.setProdiaSteps,
cfgScale: state.prodiaCfgScale, setCfgScale: state.setProdiaCfgScale,
prodiaAspectRatio: state.prodiaAspectRatio, setProdiaAspectRatio: state.setProdiaAspectRatio,
upscale: state.prodiaUpscale, setUpscale: state.setProdiaUpscale,
prodiaResolution: state.prodiaResolution, setProdiaResolution: state.setProdiaResolution,
seed: state.prodiaSeed, setSeed: state.setProdiaSeed,
}), shallow);

Expand All @@ -41,10 +47,31 @@ export function ProdiaSettings() {
staleTime: 1000 * 60 * 60, // 1 hour
});

const handleModelChange = (e: any, value: string | null) => value && setModelId(value);
// [effect] if no model is selected, auto-select the first
React.useEffect(() => {
if (modelsData && modelsData.models && !modelId) {
setModelId(modelsData.models[0].id);
setModelGen(modelsData.models[0].gen);
}
}, [modelsData, modelId, setModelId, setModelGen]);

const handleModelChange = (_event: any, value: string | null) => {
if (value) {
const prodiaModel = modelsData?.models?.find(model => model.id === value) ?? null;
if (prodiaModel) {
setModelId(prodiaModel.id);
setModelGen(prodiaModel.gen);
}
}
};

const handleResolutionChange = (_event: any, value: string | null) => value && setProdiaResolution(value);

const colWidth = 150;

// reference the currently selected model
const selectedIsXL = modelGen === 'sdxl';

return (
<Stack direction='column' sx={{ gap: settingsGap }}>

Expand All @@ -67,18 +94,19 @@ export function ProdiaSettings() {
</FormLabel>
<Select
variant='outlined' placeholder={isValidKey ? 'Select a model' : 'Enter API Key'}
value={modelId || prodiaDefaultModelId} onChange={handleModelChange}
startDecorator={<FormatPaintIcon />}
value={modelId} onChange={handleModelChange}
startDecorator={<FormatPaintIcon sx={{ display: { xs: 'none', sm: 'inherit' } }} />}
endDecorator={isValidKey && loadingModels && <CircularProgress size='sm' />}
indicator={<KeyboardArrowDownIcon />}
slotProps={{
root: { sx: { width: '100%' } },
indicator: { sx: { opacity: 0.5 } },
button: { sx: { whiteSpace: 'inherit' } },
}}
>
{modelsData && modelsData.models?.map((model, idx) => (
<Option key={'prodia-model-' + idx} value={model.id}>
{model.label}
<Option key={'prodia-model-' + idx} value={model.id} sx={model.priority ? { fontWeight: 500 } : undefined}>
{model.gen === 'sdxl' && <Chip size='sm' variant='outlined'>XL</Chip>} {model.label}
</Option>
))}
</Select>
Expand Down Expand Up @@ -117,7 +145,7 @@ export function ProdiaSettings() {
</Box>
<Slider
aria-label='Image Generation steps' valueLabelDisplay='auto'
value={steps} onChange={(e, value) => setSteps(value as number)}
value={steps} onChange={(_event, value) => setSteps(value as number)}
min={10} max={50} step={1} defaultValue={25}
sx={{ width: '100%' }}
/>
Expand All @@ -136,16 +164,38 @@ export function ProdiaSettings() {
</Box>
<Slider
aria-label='Image Generation Guidance' valueLabelDisplay='auto'
value={cfgScale} onChange={(e, value) => setCfgScale(value as number)}
value={cfgScale} onChange={(_event, value) => setCfgScale(value as number)}
min={1} max={15} step={0.5} defaultValue={7}
sx={{ width: '100%' }}
/>
</FormControl>

<FormControl orientation='horizontal' sx={{ justifyContent: 'space-between' }}>
{advanced.on && selectedIsXL && <FormControl orientation='horizontal' sx={{ justifyContent: 'space-between', alignItems: 'center' }}>
<FormLabel sx={{ minWidth: colWidth }}>
[SDXL] Resolution
</FormLabel>
<Select
variant='outlined'
value={prodiaResolution || DEFAULT_PRODIA_RESOLUTION} onChange={handleResolutionChange}
// indicator={<KeyboardArrowDownIcon />}
slotProps={{
root: { sx: { width: '100%' } },
indicator: { sx: { opacity: 0.5 } },
button: { sx: { whiteSpace: 'inherit' } },
}}
>
{HARDCODED_PRODIA_RESOLUTIONS.map((resolution) => (
<Option key={'sdxl-res-' + resolution} value={resolution}>
{resolution.replace('x', ' x ')}
</Option>
))}
</Select>
</FormControl>}

{advanced.on && !selectedIsXL && <FormControl orientation='horizontal' sx={{ justifyContent: 'space-between' }}>
<Box>
<FormLabel sx={{ minWidth: colWidth }}>
Aspect Ratio
[SD] Aspect Ratio
</FormLabel>
<FormHelperText>
{prodiaAspectRatio === 'square' ? 'Square' : prodiaAspectRatio === 'portrait' ? 'Portrait' : 'Landscape'}
Expand All @@ -156,12 +206,12 @@ export function ProdiaSettings() {
<Radio value='portrait' label={<StayPrimaryPortraitIcon sx={{ width: 25, height: 24, mt: -0.25 }} />} />
<Radio value='landscape' label={<StayPrimaryLandscapeIcon sx={{ width: 25, height: 24, mt: -0.25 }} />} />
</RadioGroup>
</FormControl>
</FormControl>}

<FormControl orientation='horizontal' sx={{ justifyContent: 'space-between' }}>
{advanced.on && !selectedIsXL && <FormControl orientation='horizontal' sx={{ justifyContent: 'space-between' }}>
<Box>
<FormLabel sx={{ minWidth: colWidth }}>
Upscale <InfoOutlinedIcon sx={{ mx: 0.5 }} />
[SD] Upscale <InfoOutlinedIcon sx={{ mx: 0.5 }} />
</FormLabel>
<FormHelperText>
{upscale ? '1024px' : 'Default'}
Expand All @@ -170,9 +220,9 @@ export function ProdiaSettings() {
<Switch checked={upscale} onChange={(e) => setUpscale(e.target.checked)}
endDecorator={upscale ? '2x' : 'Off'}
slotProps={{ endDecorator: { sx: { minWidth: 26 } } }} />
</FormControl>
</FormControl>}

<FormControl orientation='horizontal' sx={{ justifyContent: 'space-between', alignItems: 'center' }}>
{advanced.on && <FormControl orientation='horizontal' sx={{ justifyContent: 'space-between', alignItems: 'center' }}>
<Box>
<Tooltip title='Set value for reproducible images. Different by default.'>
<FormLabel sx={{ minWidth: colWidth }}>
Expand All @@ -195,7 +245,12 @@ export function ProdiaSettings() {
}}
sx={{ width: '100%' }}
/>
</FormControl>
</FormControl>}

<FormLabel onClick={advanced.toggle} sx={{ textDecoration: 'underline', cursor: 'pointer' }}>
{advanced.on ? 'Hide Advanced' : 'Advanced'}
{/*{selectedIsXL ? 'XL' : ''} Settings*/}
</FormLabel>

</Stack>
);
Expand Down
41 changes: 41 additions & 0 deletions src/modules/prodia/prodia.client.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,50 @@
import { apiAsync } from '~/common/util/trpc.client';

import { useProdiaStore } from './store-prodia';


export const requireUserKeyProdia = !process.env.HAS_SERVER_KEY_PRODIA;

export const canUseProdia = (): boolean => !!useProdiaStore.getState().prodiaModelId || !requireUserKeyProdia;

export const isValidProdiaApiKey = (apiKey?: string) => !!apiKey && apiKey.trim()?.length >= 36;

export const CmdRunProdia: string[] = ['/imagine', '/img'];


export async function prodiaGenerateImage(count: number, imageText: string) {
// use the most current model and settings
const {
prodiaApiKey: prodiaKey, prodiaModelId, prodiaModelGen,
prodiaNegativePrompt: negativePrompt, prodiaSteps: steps, prodiaCfgScale: cfgScale,
prodiaAspectRatio: aspectRatio, prodiaUpscale: upscale,
prodiaResolution: resolution,
prodiaSeed: seed,
} = useProdiaStore.getState();

// Run the image generation 'count' times in parallel
const imageUrls: string[] = await Promise.all(
// using an array of 'count' number of promises
Array(count).fill(undefined).map(async () => {

const { imageUrl } = await apiAsync.prodia.imagine.query({
...(!!prodiaKey && { prodiaKey }),
prodiaModel: prodiaModelId || 'Realistic_Vision_V5.0.safetensors [614d1063]', // data versioning fix
prodiaGen: prodiaModelGen || 'sd', // data versioning fix
prompt: imageText,
...(!!negativePrompt && { negativePrompt }),
...(!!steps && { steps }),
...(!!cfgScale && { cfgScale }),
...(!!aspectRatio && aspectRatio !== 'square' && { aspectRatio }),
...(upscale && { upscale }),
...(!!resolution && { resolution }),
...(!!seed && { seed }),
});

return imageUrl;
}),
);

// Return the resulting image URLs
return imageUrls;
}
Loading

1 comment on commit a8839b7

@vercel
Copy link

@vercel vercel bot commented on a8839b7 Oct 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

big-agi – ./

big-agi-enricoros.vercel.app
big-agi-git-main-enricoros.vercel.app
get.big-agi.com

Please sign in to comment.