Skip to content

Commit

Permalink
Merge pull request #96 from Cirilla-zmh/lumian/fea_dashscope_o11y
Browse files Browse the repository at this point in the history
feat: add DashScopeChatModel call and stream observability
  • Loading branch information
chickenlj authored Nov 9, 2024
2 parents 7ba17a8 + 78b683c commit 1a63340
Show file tree
Hide file tree
Showing 8 changed files with 456 additions and 81 deletions.
6 changes: 6 additions & 0 deletions spring-ai-alibaba-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@
<artifactId>reactor-test</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-observation-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,19 @@
import java.util.concurrent.ConcurrentHashMap;

import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.chat.observation.DashScopeChatModelObservationConvention;
import com.alibaba.cloud.ai.dashscope.metadata.DashScopeAiUsage;
import com.alibaba.cloud.ai.observation.conventions.AiProvider;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.*;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import reactor.core.publisher.Flux;

import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.*;
Expand All @@ -25,10 +35,6 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
Expand All @@ -55,15 +61,27 @@ public class DashScopeChatModel extends AbstractToolCallSupport implements ChatM

private static final Logger logger = LoggerFactory.getLogger(DashScopeChatModel.class);

private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DashScopeChatModelObservationConvention();

/** Low-level access to the DashScope API */
private final DashScopeApi dashscopeApi;

/** The retry template used to retry the OpenAI API calls. */
public final RetryTemplate retryTemplate;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/** The default options used for the chat completion requests. */
private DashScopeChatOptions defaultOptions;

/**
* Conventions to use for generating observations.
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public DashScopeChatModel(DashScopeApi dashscopeApi) {
this(dashscopeApi,
DashScopeChatOptions.builder()
Expand All @@ -78,6 +96,12 @@ public DashScopeChatModel(DashScopeApi dashscopeApi, DashScopeChatOptions option

/**
 * Creates a chat model without observability instrumentation.
 * Delegates to the full constructor with {@link ObservationRegistry#NOOP}, so no
 * observations are recorded for calls made through this instance.
 * @param dashscopeApi low-level DashScope API client, must not be null
 * @param options default chat options applied when the prompt carries none
 * @param functionCallbackContext context used to resolve function-calling callbacks
 * @param retryTemplate retry policy applied to DashScope API calls
 */
public DashScopeChatModel(DashScopeApi dashscopeApi, DashScopeChatOptions options,
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
this(dashscopeApi, options, functionCallbackContext, retryTemplate, ObservationRegistry.NOOP);
}

/**
 * Creates a fully configured chat model.
 * This is the canonical constructor; all other constructors delegate here.
 * Defect fixed: the scraped diff spliced a hunk header into this body, dropping the
 * retryTemplate null-check — restored the complete post-commit constructor.
 * @param dashscopeApi low-level DashScope API client, must not be null
 * @param options default chat options applied when the prompt carries none, must not be null
 * @param functionCallbackContext context used to resolve function-calling callbacks
 * @param retryTemplate retry policy applied to DashScope API calls, must not be null
 * @param observationRegistry registry receiving chat-model observations; use
 * {@link ObservationRegistry#NOOP} to disable instrumentation
 */
public DashScopeChatModel(DashScopeApi dashscopeApi, DashScopeChatOptions options,
        FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate,
        ObservationRegistry observationRegistry) {
    super(functionCallbackContext);
    Assert.notNull(dashscopeApi, "DashScopeApi must not be null");
    Assert.notNull(options, "Options must not be null");
    Assert.notNull(retryTemplate, "RetryTemplate must not be null");
    this.dashscopeApi = dashscopeApi;
    this.defaultOptions = options;
    this.retryTemplate = retryTemplate;
    this.observationRegistry = observationRegistry;
}

@Override
public ChatResponse call(Prompt prompt) {
DashScopeApi.ChatCompletionRequest request = createRequest(prompt, false);

ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
.execute(ctx -> this.dashscopeApi.chatCompletionEntity(request));
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.DASHSCOPE.value())
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
.build();

var chatCompletion = completionEntity.getBody();
ChatResponse chatResponse = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
DashScopeApi.ChatCompletionRequest request = createRequest(prompt, false);

if (chatCompletion == null) {
logger.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}
ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
.execute(ctx -> this.dashscopeApi.chatCompletionEntity(request));

var chatCompletion = completionEntity.getBody();

if (chatCompletion == null) {
logger.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}

List<ChatCompletionOutput.Choice> choices = chatCompletion.output().choices();
List<ChatCompletionOutput.Choice> choices = chatCompletion.output().choices();

List<Generation> generations = choices.stream().map(choice -> {
List<Generation> generations = choices.stream().map(choice -> {
// @formatter:off
Map<String, Object> metadata = Map.of(
"id", chatCompletion.requestId(),
"role", choice.message().role() != null ? choice.message().role().name() : "",
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();
Map<String, Object> metadata = Map.of(
"id", chatCompletion.requestId(),
"role", choice.message().role() != null ? choice.message().role().name() : "",
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

ChatResponse response = new ChatResponse(generations, from(completionEntity.getBody()));

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
observationContext.setResponse(response);

return response;
});

if (isToolCall(chatResponse,
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
Expand All @@ -134,62 +175,83 @@ public ChatOptions getDefaultOptions() {

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
ChatCompletionRequest request = createRequest(prompt, true);

Flux<ChatCompletionChunk> completionChunks = this.retryTemplate
.execute(ctx -> this.dashscopeApi.chatCompletionStream(request));

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

Flux<ChatCompletionChunk> completionChunks = this.retryTemplate
.execute(ctx -> this.dashscopeApi.chatCompletionStream(request));

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.DASHSCOPE.value())
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry);

observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
// the function call handling logic.
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
try {
@SuppressWarnings("null")
String requestId = chatCompletion2.requestId();

// @formatter:off
List<Generation> generations = chatCompletion2.output().choices().stream().map(choice -> {
if (choice.message().role() != null) {
roleMap.putIfAbsent(requestId, choice.message().role().name());
}
Map<String, Object> metadata = Map.of(
"id", chatCompletion2.requestId(),
"role", roleMap.getOrDefault(requestId, ""),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
return buildGeneration(choice, metadata);
}).toList();
// @formatter:on

if (chatCompletion2.usage() != null) {
return new ChatResponse(generations, from(chatCompletion2));
}
else {
return new ChatResponse(generations);
}
}
catch (Exception e) {
logger.error("Error processing chat completion", e);
return new ChatResponse(List.of());
}

// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
// the function call handling logic.
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
try {
@SuppressWarnings("null")
String requestId = chatCompletion2.requestId();
}));

// @formatter:off
List<Generation> generations = chatCompletion2.output().choices().stream().map(choice -> {
if (choice.message().role() != null) {
roleMap.putIfAbsent(requestId, choice.message().role().name());
}
Map<String, Object> metadata = Map.of(
"id", chatCompletion2.requestId(),
"role", roleMap.getOrDefault(requestId, ""),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
return buildGeneration(choice, metadata);
}).toList();
// @formatter:on

if (chatCompletion2.usage() != null) {
return new ChatResponse(generations, from(chatCompletion2));
}
else {
return new ChatResponse(generations);
}
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {

if (isToolCall(response,
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
catch (Exception e) {
logger.error("Error processing chat completion", e);
return new ChatResponse(List.of());
else {
return Flux.just(response);
}
})
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));

}));

return chatResponse.flatMap(response -> {

if (isToolCall(response,
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
else {
return Flux.just(response);
}
return new MessageAggregator().aggregate(flux, observationContext::setResponse);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,30 @@ public class DashScopeChatOptions implements FunctionCallingOptions, ChatOptions
@JsonIgnore
private Map<String, Object> toolContext;

@Override
public String getModel() {
return model;
}@Override public Double getFrequencyPenalty() {
}

@Override
public Double getFrequencyPenalty() {
return null;
}@Override public Integer getMaxTokens() {
}

@Override
public Integer getMaxTokens() {
return null;
}@Override public Double getPresencePenalty() {
}

@Override
public Double getPresencePenalty() {
return null;
}@Override public List<String> getStopSequences() {
}

@Override
public List<String> getStopSequences() {
return null;
}
}

public void setModel(String model) {
this.model = model;
Expand All @@ -161,9 +174,12 @@ public void setTemperature(Double temperature) {
@Override
public Double getTopP() {
return this.topP;
}@Override public ChatOptions copy() {
}

@Override
public ChatOptions copy() {
return DashScopeChatOptions.fromOptions(this);
}
}

public void setTopP(Double topP) {
this.topP = topP;
Expand Down Expand Up @@ -264,7 +280,7 @@ public void setIncrementalOutput(Boolean incrementalOutput) {
this.incrementalOutput = incrementalOutput;
}

public Boolean getVlHighResolutionImages() {
public Boolean getVlHighResolutionImages() {
return vlHighResolutionImages;
}

Expand Down
Loading

0 comments on commit 1a63340

Please sign in to comment.