Skip to content

Commit

Permalink
feat: add bailian multimodel support (#63)
Browse files Browse the repository at this point in the history
1. add dashscope multimodel support
  • Loading branch information
robinyeeh authored Oct 16, 2024
1 parent 3cc5cc4 commit 63ddbfa
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.alibaba.cloud.ai.dashscope.rag.DashScopeDocumentRetrieverOptions;
import com.alibaba.cloud.ai.dashscope.rag.DashScopeDocumentTransformerOptions;
import com.alibaba.cloud.ai.dashscope.rag.DashScopeStoreOptions;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.chat.metadata.Usage;
Expand Down Expand Up @@ -810,7 +811,7 @@ public Function(String description, String name, String jsonSchema) {
public record ChatCompletionRequest(@JsonProperty("model") String model,
@JsonProperty("input") ChatCompletionRequestInput input,
@JsonProperty("parameters") ChatCompletionRequestParameter parameters,
@JsonProperty("stream") Boolean stream) {
@JsonProperty("stream") Boolean stream, @JsonIgnore Boolean multiModel) {

/**
* Shortcut constructor for a chat completion request with the given messages and
Expand All @@ -819,7 +820,7 @@ public record ChatCompletionRequest(@JsonProperty("model") String model,
* @param input request input of chat.
*/
public ChatCompletionRequest(String model, ChatCompletionRequestInput input, Boolean stream) {
this(model, input, null, stream);
this(model, input, null, stream, false);
}
}

Expand Down Expand Up @@ -859,13 +860,14 @@ public record ChatCompletionRequestParameter(@JsonProperty("result_format") Stri
@JsonProperty("stop") List<Object> stop, @JsonProperty("enable_search") Boolean enableSearch,
@JsonProperty("incremental_output") Boolean incrementalOutput,
@JsonProperty("tools") List<FunctionTool> tools, @JsonProperty("tool_choice") Object toolChoice,
@JsonProperty("stream") Boolean stream) {
@JsonProperty("stream") Boolean stream,
@JsonProperty("vl_high_resolution_images") Boolean vlHighResolutionImages) {

/**
* shortcut constructor for chat request parameter
*/
public ChatCompletionRequestParameter() {
this(null, null, null, null, null, null, null, null, null, null, null, null, null, null);
this(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);
}

/**
Expand Down Expand Up @@ -937,12 +939,28 @@ public record ChatCompletionMessage(@JsonProperty("content") Object rawContent,
*/
public String content() {
if (this.rawContent == null) {
return null;
return "";
}

if (this.rawContent instanceof String text) {
return text;
}
throw new IllegalStateException("The content is not a string!");

if (this.rawContent instanceof List list) {
if (list.isEmpty()) {
return "";
}

Object object = list.get(0);
if (object instanceof Map map) {
if (map.isEmpty() || map.get("text") == null) {
return "";
}

return map.get("text").toString();
}
}
throw new IllegalStateException("The content is not valid!");
}

/**
Expand Down Expand Up @@ -987,44 +1005,20 @@ public enum Role {
* An array of content parts with a defined type. Each MediaContent can be of
* either "text" or "image_url" type. Not both.
*
* @param type Content type, each can be of type text or image_url.
* @param text The text content of the message.
* @param imageUrl The image content of the message. You can pass multiple images
* by adding multiple image_url content parts. Image input is only supported when
* using the glm-4v model.
* @param image The image content of the message. You can pass multiple images
* by adding multiple image content parts. Image input is only supported when
* using the glm-4v model.
* @param video The list of images that make up a video.
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public record MediaContent(@JsonProperty("type") String type, @JsonProperty("text") String text,
@JsonProperty("image_url") ImageUrl imageUrl) {

/**
* @param url Either a URL of the image or the base64 encoded image data. The
* base64 encoded image data must have a special prefix in the following
* format: "data:{mimetype};base64,{base64-encoded-image-data}".
* @param detail Specifies the detail level of the image.
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public record ImageUrl(@JsonProperty("url") String url, @JsonProperty("detail") String detail) {

public ImageUrl(String url) {
this(url, null);
}
}

@JsonProperty("image") String image, @JsonProperty("video") List<String> video) {
/**
* Shortcut constructor for a text content.
* @param text The text content of the message.
*/
public MediaContent(String text) {
this("text", text, null);
}

/**
* Shortcut constructor for an image content.
* @param imageUrl The image content of the message.
*/
public MediaContent(ImageUrl imageUrl) {
this("image_url", null, imageUrl);
this("text", text, null, null);
}
}

Expand Down Expand Up @@ -1298,11 +1292,12 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
Assert.notNull(chatRequest, "The request body can not be null.");
Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");

return this.restClient.post()
.uri("/api/v1/services/aigc/text-generation/generation")
.body(chatRequest)
.retrieve()
.toEntity(ChatCompletion.class);
String uri = "/api/v1/services/aigc/text-generation/generation";
if (chatRequest.multiModel()) {
uri = "/api/v1/services/aigc/multimodal-generation/generation";
}

return this.restClient.post().uri(uri).body(chatRequest).retrieve().toEntity(ChatCompletion.class);
}

/**
Expand All @@ -1322,8 +1317,13 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
DashScopeAiStreamFunctionCallingHelper chunkMerger = new DashScopeAiStreamFunctionCallingHelper(
incrementalOutput);

String uri = "/api/v1/services/aigc/text-generation/generation";
if (chatRequest.multiModel()) {
uri = "/api/v1/services/aigc/multimodal-generation/generation";
}

return this.webClient.post()
.uri("/api/v1/services/aigc/text-generation/generation")
.uri(uri)
.header("X-DashScope-SSE", "enable")
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
.retrieve()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.*;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionMessage.*;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionMessage.MediaContent;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionOutput.Choice;
import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -52,6 +51,8 @@
*/
public class DashScopeChatModel extends AbstractToolCallSupport implements ChatModel {

public static final String MESSAGE_FORMAT = "messageFormat";

private static final Logger logger = LoggerFactory.getLogger(DashScopeChatModel.class);

/** Low-level access to the DashScope API */
Expand Down Expand Up @@ -256,16 +257,7 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
Object content = message.getContent();
if (message instanceof UserMessage userMessage) {
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<MediaContent> contentList = new ArrayList<>(
List.of(new MediaContent(message.getContent())));

contentList.addAll(userMessage.getMedia()
.stream()
.map(media -> new MediaContent(new MediaContent.ImageUrl(
this.fromMediaData(media.getMimeType(), media.getData()))))
.toList());

content = contentList;
content = convertMediaContent(userMessage);
}
}

Expand Down Expand Up @@ -303,8 +295,41 @@ else if (message.getMessageType() == MessageType.TOOL) {
}
}).flatMap(List::stream).toList();

boolean multiModel = options.getMultiModel();
return new ChatCompletionRequest(options.getModel(), new ChatCompletionRequestInput(chatCompletionMessages),
toDashScopeRequestParameter(options, stream), stream);
toDashScopeRequestParameter(options, stream), stream, multiModel);
}

/**
 * Converts a user message that carries media into DashScope content parts.
 * <p>
 * The first part is always the message text. The layout of the media is chosen by
 * the {@code MESSAGE_FORMAT} metadata entry: {@link MessageFormat#VIDEO} packs all
 * media into a single "video" part, while any other value (including a missing or
 * non-{@link MessageFormat} entry) produces one "image" part per media item.
 * @param message the user message whose text and media are converted
 * @return the text part followed by the media part(s)
 */
private List<MediaContent> convertMediaContent(UserMessage message) {
    Object declared = message.getMetadata().get(MESSAGE_FORMAT);
    MessageFormat format = (declared instanceof MessageFormat explicit) ? explicit : MessageFormat.IMAGE;

    List<MediaContent> parts = new ArrayList<>();
    // Both layouts start with the textual part of the message.
    parts.add(new MediaContent(message.getContent()));

    if (format != MessageFormat.VIDEO) {
        // Image layout: every media item becomes its own "image" part.
        List<MediaContent> images = message.getMedia()
            .stream()
            .map(media -> this.fromMediaData(media.getMimeType(), media.getData()))
            .map(data -> new MediaContent("image", null, data, null))
            .toList();
        parts.addAll(images);
        return parts;
    }

    // Video layout: all media items are bundled into one "video" part.
    List<String> frames = message.getMedia()
        .stream()
        .map(media -> this.fromMediaData(media.getMimeType(), media.getData()))
        .toList();
    parts.add(new MediaContent("video", null, null, frames));
    return parts;
}

private String fromMediaData(MimeType mimeType, Object mediaContentData) {
Expand Down Expand Up @@ -340,7 +365,7 @@ private ChatCompletionRequestParameter toDashScopeRequestParameter(DashScopeChat
return new ChatCompletionRequestParameter("message", options.getSeed(), options.getMaxTokens(),
options.getTopP(), options.getTopK(), options.getRepetitionPenalty(), options.getPresencePenalty(),
options.getTemperature(), options.getStop(), options.getEnableSearch(), incrementalOutput,
options.getTools(), options.getToolChoice(), stream);
options.getTools(), options.getToolChoice(), stream, options.getVlHighResolutionImages());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,13 @@ public class DashScopeChatOptions implements FunctionCallingOptions, ChatOptions
private @JsonProperty("tool_choice") Object toolChoice;

/**
* OpenAI Tool Function Callbacks to register with the ChatClient. For Prompt Options the
* Raises the token limit to 16384 for vision-language (VL) models. Only supported by
* VL models, including qwen-vl-max, qwen-vl-max-0809 and qwen-vl-plus-0809.
*/
private @JsonProperty("vl_high_resolution_images") Boolean vlHighResolutionImages;

/**
* Tool Function Callbacks to register with the ChatClient. For Prompt Options the
* functionCallbacks are automatically enabled for the duration of the prompt execution. For
* Default Options the functionCallbacks are registered but disabled by default. Use the
* enableFunctions to set the functions from the registry to be used by the ChatClient chat
Expand All @@ -110,6 +116,11 @@ public class DashScopeChatOptions implements FunctionCallingOptions, ChatOptions
*/
@NestedConfigurationProperty @JsonIgnore private Set<String> functions = new HashSet<>();

/**
* Indicates whether the request is a multimodal request.
*/
private @JsonProperty("multi_model") Boolean multiModel = false;

@NestedConfigurationProperty
@JsonIgnore
private Map<String, Object> toolContext;
Expand Down Expand Up @@ -253,6 +264,22 @@ public void setIncrementalOutput(Boolean incrementalOutput) {
this.incrementalOutput = incrementalOutput;
}

/**
 * Returns the flag that is serialized as {@code vl_high_resolution_images}.
 * @return the configured flag, or {@code null} when it has not been set
 */
public Boolean getVlHighResolutionImages() {
return vlHighResolutionImages;
}

/**
 * Sets the flag that is serialized as {@code vl_high_resolution_images}.
 * @param vlHighResolutionImages the new flag value; stored as given
 */
public void setVlHighResolutionImages(Boolean vlHighResolutionImages) {
this.vlHighResolutionImages = vlHighResolutionImages;
}

/**
 * Returns the multi-model flag (JSON property {@code multi_model}).
 * @return the stored flag; the field is initialized to {@code false}
 */
public Boolean getMultiModel() {
return multiModel;
}

/**
 * Sets the multi-model flag (JSON property {@code multi_model}).
 * @param multiModel the new flag value; stored as given, without a null check
 */
public void setMultiModel(Boolean multiModel) {
this. multiModel = multiModel;
}

public static DashscopeChatOptionsBuilder builder() {
return new DashscopeChatOptionsBuilder();
}
Expand Down Expand Up @@ -352,6 +379,16 @@ public DashscopeChatOptionsBuilder withToolContext(Map<String, Object> toolConte
return this;
}

/**
 * Sets the {@code vlHighResolutionImages} flag on the options being built.
 * @param vlHighResolutionImages the flag value to store
 * @return this builder, for call chaining
 */
public DashscopeChatOptionsBuilder withVlHighResolutionImages(Boolean vlHighResolutionImages) {
this.options.vlHighResolutionImages = vlHighResolutionImages;
return this;
}

/**
 * Sets the {@code multiModel} flag on the options being built. The value is written
 * directly to the field, so a {@code null} argument replaces the {@code false} default.
 * @param multiModel the flag value to store
 * @return this builder, for call chaining
 */
public DashscopeChatOptionsBuilder withMultiModel(Boolean multiModel) {
this.options.multiModel = multiModel;
return this;
}

public DashScopeChatOptions build() {
return this.options;
}
Expand All @@ -363,6 +400,7 @@ public static DashScopeChatOptions fromOptions(DashScopeChatOptions fromOptions)
.withTemperature(fromOptions.getTemperature())
.withTopP(fromOptions.getTopP())
.withTopK(fromOptions.getTopK())
.withSeed(fromOptions.getSeed())
.withStop(fromOptions.getStop())
.withStream(fromOptions.getStream())
.withEnableSearch(fromOptions.enableSearch)
Expand All @@ -372,6 +410,8 @@ public static DashScopeChatOptions fromOptions(DashScopeChatOptions fromOptions)
.withRepetitionPenalty(fromOptions.getRepetitionPenalty())
.withTools(fromOptions.getTools())
.withToolContext(fromOptions.getToolContext())
.withMultiModel(fromOptions.getMultiModel())
.withVlHighResolutionImages(fromOptions.getVlHighResolutionImages())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.cloud.ai.dashscope.chat;

/**
 * Format of the media attached to a user message.<br>
 * {@code DashScopeChatModel} reads this value from the message metadata to decide
 * whether the media is sent as individual "image" content parts or as a single
 * "video" content part; {@link #IMAGE} is used when no format is supplied.
 *
 * @author yuanci.ytb
 * @since 1.0.0-M2
 */

public enum MessageFormat {

/**
 * Each media item is sent as its own "image" content part.
 */
IMAGE,

/**
 * All media items are sent together as the list of one "video" content part.
 */
VIDEO

}
Loading

0 comments on commit 63ddbfa

Please sign in to comment.