⚠ This page is served via a proxy. Original site: https://github.com
This service does not collect credentials or authentication data.
Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 36 additions & 31 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ public abstract class BaseAgent {
*/
private BaseAgent parentAgent;

private final List<? extends BaseAgent> subAgents;
private final ImmutableList<? extends BaseAgent> subAgents;

private final Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback;
private final Optional<List<? extends AfterAgentCallback>> afterAgentCallback;
private final ImmutableList<? extends BeforeAgentCallback> beforeAgentCallback;
private final ImmutableList<? extends AfterAgentCallback> afterAgentCallback;

/**
* Creates a new BaseAgent.
Expand All @@ -82,9 +82,13 @@ public BaseAgent(
this.name = name;
this.description = description;
this.parentAgent = null;
this.subAgents = subAgents != null ? subAgents : ImmutableList.of();
this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback);
this.afterAgentCallback = Optional.ofNullable(afterAgentCallback);
this.subAgents = subAgents == null ? ImmutableList.of() : ImmutableList.copyOf(subAgents);
this.beforeAgentCallback =
beforeAgentCallback == null
? ImmutableList.of()
: ImmutableList.copyOf(beforeAgentCallback);
this.afterAgentCallback =
afterAgentCallback == null ? ImmutableList.of() : ImmutableList.copyOf(afterAgentCallback);

// Establish parent relationships for all sub-agents if needed.
for (BaseAgent subAgent : this.subAgents) {
Expand Down Expand Up @@ -144,38 +148,40 @@ public BaseAgent rootAgent() {
/**
* Finds an agent (this or descendant) by name.
*
* @return the agent or descendant with the given name, or {@code null} if not found.
* @return an {@link Optional} containing the agent or descendant with the given name, or {@link
* Optional#empty()} if not found.
*/
public BaseAgent findAgent(String name) {
public Optional<BaseAgent> findAgent(String name) {
if (this.name().equals(name)) {
return this;
return Optional.of(this);
}
return findSubAgent(name);
}

/** Recursively search sub agent by name. */
public @Nullable BaseAgent findSubAgent(String name) {
for (BaseAgent subAgent : subAgents) {
if (subAgent.name().equals(name)) {
return subAgent;
}
BaseAgent result = subAgent.findSubAgent(name);
if (result != null) {
return result;
}
}
return null;
/**
* Recursively search sub agent by name.
*
* @return an {@link Optional} containing the sub agent with the given name, or {@link
* Optional#empty()} if not found.
*/
public Optional<BaseAgent> findSubAgent(String name) {
return subAgents.stream()
.map(
subAgent ->
subAgent.name().equals(name) ? Optional.of(subAgent) : subAgent.findSubAgent(name))
.flatMap(Optional::stream)
.findFirst();
}

public List<? extends BaseAgent> subAgents() {
public ImmutableList<? extends BaseAgent> subAgents() {
return subAgents;
}

public Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback() {
public ImmutableList<? extends BeforeAgentCallback> beforeAgentCallback() {
return beforeAgentCallback;
}

public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
public ImmutableList<? extends AfterAgentCallback> afterAgentCallback() {
return afterAgentCallback;
}

Expand All @@ -184,17 +190,17 @@ public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
*
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends BeforeAgentCallback> canonicalBeforeAgentCallbacks() {
return beforeAgentCallback.orElse(ImmutableList.of());
public ImmutableList<? extends BeforeAgentCallback> canonicalBeforeAgentCallbacks() {
return beforeAgentCallback;
}

/**
* The resolved afterAgentCallback field as a list.
*
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends AfterAgentCallback> canonicalAfterAgentCallbacks() {
return afterAgentCallback.orElse(ImmutableList.of());
public ImmutableList<? extends AfterAgentCallback> canonicalAfterAgentCallbacks() {
return afterAgentCallback;
}

/**
Expand Down Expand Up @@ -239,8 +245,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
() ->
callCallback(
beforeCallbacksToFunctions(
invocationContext.pluginManager(),
beforeAgentCallback.orElse(ImmutableList.of())),
invocationContext.pluginManager(), beforeAgentCallback),
invocationContext)
.flatMapPublisher(
beforeEventOpt -> {
Expand All @@ -257,7 +262,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
callCallback(
afterCallbacksToFunctions(
invocationContext.pluginManager(),
afterAgentCallback.orElse(ImmutableList.of())),
afterAgentCallback),
invocationContext)
.flatMapPublisher(Flowable::fromOptional));

Expand Down
7 changes: 4 additions & 3 deletions core/src/main/java/com/google/adk/agents/BaseAgentConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.adk.agents;

import java.util.Collections;
import java.util.List;

/**
Expand Down Expand Up @@ -132,23 +133,23 @@ public String agentClass() {
}

public List<AgentRefConfig> subAgents() {
return subAgents;
return subAgents == null ? Collections.emptyList() : subAgents;
}

public void setSubAgents(List<AgentRefConfig> subAgents) {
this.subAgents = subAgents;
}

public List<CallbackRef> beforeAgentCallbacks() {
return beforeAgentCallbacks;
return beforeAgentCallbacks == null ? Collections.emptyList() : beforeAgentCallbacks;
}

public void setBeforeAgentCallbacks(List<CallbackRef> beforeAgentCallbacks) {
this.beforeAgentCallbacks = beforeAgentCallbacks;
}

public List<CallbackRef> afterAgentCallbacks() {
return afterAgentCallbacks;
return afterAgentCallbacks == null ? Collections.emptyList() : afterAgentCallbacks;
}

public void setAfterAgentCallbacks(List<CallbackRef> afterAgentCallbacks) {
Expand Down
45 changes: 21 additions & 24 deletions core/src/main/java/com/google/adk/agents/CallbackUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.reactivex.rxjava3.core.Maybe;
import java.util.List;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -38,45 +37,43 @@ public final class CallbackUtil {
* Normalizes before-agent callbacks.
*
* @param beforeAgentCallback Callback list (sync or async).
* @return normalized async callbacks, or null if input is null.
* @return normalized async callbacks, or empty list if input is null.
*/
@CanIgnoreReturnValue
public static @Nullable ImmutableList<BeforeAgentCallback> getBeforeAgentCallbacks(
public static ImmutableList<BeforeAgentCallback> getBeforeAgentCallbacks(
List<BeforeAgentCallbackBase> beforeAgentCallback) {
if (beforeAgentCallback == null) {
return null;
} else if (beforeAgentCallback.isEmpty()) {
if (beforeAgentCallback == null || beforeAgentCallback.isEmpty()) {
return ImmutableList.of();
} else {
ImmutableList.Builder<BeforeAgentCallback> builder = ImmutableList.builder();
for (BeforeAgentCallbackBase callback : beforeAgentCallback) {
if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) {
builder.add(beforeAgentCallbackInstance);
} else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) {
builder.add(
(callbackContext) ->
Maybe.fromOptional(beforeAgentCallbackSyncInstance.call(callbackContext)));
} else {
logger.warn(
"Invalid beforeAgentCallback callback type: {}. Ignoring this callback.",
callback.getClass().getName());
}
}

ImmutableList.Builder<BeforeAgentCallback> builder = ImmutableList.builder();
for (BeforeAgentCallbackBase callback : beforeAgentCallback) {
if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) {
builder.add(beforeAgentCallbackInstance);
} else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) {
builder.add(
(callbackContext) ->
Maybe.fromOptional(beforeAgentCallbackSyncInstance.call(callbackContext)));
} else {
logger.warn(
"Invalid beforeAgentCallback callback type: {}. Ignoring this callback.",
callback.getClass().getName());
}
return builder.build();
}
return builder.build();
}

/**
* Normalizes after-agent callbacks.
*
* @param afterAgentCallback Callback list (sync or async).
* @return normalized async callbacks, or null if input is null.
* @return normalized async callbacks, or empty list if input is null.
*/
@CanIgnoreReturnValue
public static @Nullable ImmutableList<AfterAgentCallback> getAfterAgentCallbacks(
public static ImmutableList<AfterAgentCallback> getAfterAgentCallbacks(
List<AfterAgentCallbackBase> afterAgentCallback) {
if (afterAgentCallback == null) {
return null;
return ImmutableList.of();
} else if (afterAgentCallback.isEmpty()) {
return ImmutableList.of();
} else {
Expand Down
5 changes: 2 additions & 3 deletions core/src/main/java/com/google/adk/agents/LlmAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -935,9 +935,8 @@ public Optional<String> outputKey() {
return outputKey;
}

@Nullable
public BaseCodeExecutor codeExecutor() {
return codeExecutor.orElse(null);
public Optional<BaseCodeExecutor> codeExecutor() {
return codeExecutor;
}

public Model resolvedModel() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,15 +388,15 @@ private Flowable<Event> runOneStep(InvocationContext context) {
String agentToTransfer = event.actions().transferToAgent().get();
logger.debug("Transferring to agent: {}", agentToTransfer);
BaseAgent rootAgent = context.agent().rootAgent();
BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer);
if (nextAgent == null) {
Optional<BaseAgent> nextAgent = rootAgent.findAgent(agentToTransfer);
if (nextAgent.isEmpty()) {
String errorMsg = "Agent not found for transfer: " + agentToTransfer;
logger.error(errorMsg);
return postProcessedEvents.concatWith(
Flowable.error(new IllegalStateException(errorMsg)));
}
return postProcessedEvents.concatWith(
Flowable.defer(() -> nextAgent.runAsync(context)));
Flowable.defer(() -> nextAgent.get().runAsync(context)));
}
return postProcessedEvents;
});
Expand Down Expand Up @@ -574,14 +574,14 @@ public void onError(Throwable e) {
Flowable<Event> events = Flowable.just(event);
if (event.actions().transferToAgent().isPresent()) {
BaseAgent rootAgent = invocationContext.agent().rootAgent();
BaseAgent nextAgent =
Optional<BaseAgent> nextAgent =
rootAgent.findAgent(event.actions().transferToAgent().get());
if (nextAgent == null) {
if (nextAgent.isEmpty()) {
throw new IllegalStateException(
"Agent not found: " + event.actions().transferToAgent().get());
}
Flowable<Event> nextAgentEvents =
nextAgent.runLive(invocationContext);
nextAgent.get().runLive(invocationContext);
events = Flowable.concat(events, nextAgentEvents);
}
return events;
Expand Down
22 changes: 12 additions & 10 deletions core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ private static class CodeExecutionRequestProcessor implements RequestProcessor {
public Single<RequestProcessor.RequestProcessingResult> processRequest(
InvocationContext invocationContext, LlmRequest llmRequest) {
if (!(invocationContext.agent() instanceof LlmAgent llmAgent)
|| llmAgent.codeExecutor() == null) {
|| llmAgent.codeExecutor().isEmpty()) {
return Single.just(
RequestProcessor.RequestProcessingResult.create(llmRequest, ImmutableList.of()));
}

if (llmAgent.codeExecutor() instanceof BuiltInCodeExecutor builtInCodeExecutor) {
if (llmAgent.codeExecutor().get() instanceof BuiltInCodeExecutor builtInCodeExecutor) {
var llmRequestBuilder = llmRequest.toBuilder();
builtInCodeExecutor.processLlmRequest(llmRequestBuilder);
LlmRequest updatedLlmRequest = llmRequestBuilder.build();
Expand All @@ -124,8 +124,8 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
Flowable<Event> preprocessorEvents = runPreProcessor(invocationContext, llmRequest);

// Convert the code execution parts to text parts.
if (llmAgent.codeExecutor() != null) {
BaseCodeExecutor baseCodeExecutor = llmAgent.codeExecutor();
if (llmAgent.codeExecutor().isPresent()) {
BaseCodeExecutor baseCodeExecutor = llmAgent.codeExecutor().get();
List<Content> updatedContents = new ArrayList<>();
for (Content content : llmRequest.contents()) {
List<String> delimiters =
Expand Down Expand Up @@ -173,10 +173,11 @@ private static Flowable<Event> runPreProcessor(
return Flowable.empty();
}

var codeExecutor = llmAgent.codeExecutor();
if (codeExecutor == null) {
var codeExecutorOptional = llmAgent.codeExecutor();
if (codeExecutorOptional.isEmpty()) {
return Flowable.empty();
}
var codeExecutor = codeExecutorOptional.get();

if (codeExecutor instanceof BuiltInCodeExecutor) {
return Flowable.empty();
Expand Down Expand Up @@ -268,10 +269,11 @@ private static Flowable<Event> runPostProcessor(
if (!(invocationContext.agent() instanceof LlmAgent llmAgent)) {
return Flowable.empty();
}
var codeExecutor = llmAgent.codeExecutor();
if (codeExecutor == null) {
var codeExecutorOptional = llmAgent.codeExecutor();
if (codeExecutorOptional.isEmpty()) {
return Flowable.empty();
}
var codeExecutor = codeExecutorOptional.get();
if (llmResponse.content().isEmpty()) {
return Flowable.empty();
}
Expand Down Expand Up @@ -387,8 +389,8 @@ private static List<File> extractAndReplaceInlineFiles(
private static Optional<String> getOrSetExecutionId(
InvocationContext invocationContext, CodeExecutorContext codeExecutorContext) {
if (!(invocationContext.agent() instanceof LlmAgent llmAgent)
|| llmAgent.codeExecutor() == null
|| !llmAgent.codeExecutor().stateful()) {
|| llmAgent.codeExecutor().isEmpty()
|| !llmAgent.codeExecutor().get().stateful()) {
return Optional.empty();
}

Expand Down
8 changes: 4 additions & 4 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -768,14 +768,14 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) {
return rootAgent;
}

BaseAgent agent = rootAgent.findSubAgent(author);
Optional<BaseAgent> agent = rootAgent.findSubAgent(author);

if (agent == null) {
if (agent.isEmpty()) {
continue;
}

if (this.isTransferableAcrossAgentTree(agent)) {
return agent;
if (this.isTransferableAcrossAgentTree(agent.get())) {
return agent.get();
}
}

Expand Down
8 changes: 4 additions & 4 deletions core/src/test/java/com/google/adk/agents/BaseAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ public void findAgent_returnsCorrectAgent() {
TestBaseAgent agent =
new TestBaseAgent(
TEST_AGENT_NAME, TEST_AGENT_DESCRIPTION, null, ImmutableList.of(subAgent), null, null);
assertThat(agent.findAgent("subSubAgent")).isEqualTo(subSubAgent);
assertThat(agent.findAgent("subAgent")).isEqualTo(subAgent);
assertThat(agent.findAgent(TEST_AGENT_NAME)).isEqualTo(agent);
assertThat(agent.findAgent("nonExistent")).isNull();
assertThat(agent.findAgent("subSubAgent")).hasValue(subSubAgent);
assertThat(agent.findAgent("subAgent")).hasValue(subAgent);
assertThat(agent.findAgent(TEST_AGENT_NAME)).hasValue(agent);
assertThat(agent.findAgent("nonExistent")).isEmpty();
}

@Test
Expand Down
Loading