⚠ This page is served via a proxy. Original site: https://github.com
This service does not collect credentials or authentication data.
Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client {
if options != nil {
opts = *options
}
options = nil // prevent reuse

if opts.Logger == nil { // ensure we have a logger
opts.Logger = ensureLogger(nil)
Expand Down Expand Up @@ -129,6 +128,7 @@ type ClientOptions struct {
PromptListChangedHandler func(context.Context, *PromptListChangedRequest)
ResourceListChangedHandler func(context.Context, *ResourceListChangedRequest)
ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest)
TaskStatusHandler func(context.Context, *TaskStatusNotificationRequest)
LoggingMessageHandler func(context.Context, *LoggingMessageRequest)
ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest)
// If non-zero, defines an interval for regular "ping" requests.
Expand Down Expand Up @@ -807,6 +807,7 @@ var clientMethodInfos = map[string]methodInfo{
methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0),
methodElicit: newClientMethodInfo(clientMethod((*Client).elicit), missingParamsOK),
notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK),
notificationTaskStatus: newClientMethodInfo(clientMethod((*Client).callTaskStatusHandler), notification),
notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK),
notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK),
notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK),
Expand Down Expand Up @@ -888,6 +889,9 @@ func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams)
//
// The params.Arguments can be any value that marshals into a JSON object.
func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) {
if params != nil && params.Task != nil {
return nil, fmt.Errorf("task augmentation requested: use CallToolTask")
}
if params == nil {
params = new(CallToolParams)
}
Expand All @@ -898,6 +902,43 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (
return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params)))
}

// CallToolTask calls a tool using task-based execution (tools/call with params.task).
//
// The response is a CreateTaskResult. Use GetTask to poll for task state and
// TaskResult to retrieve the final tool result.
func (cs *ClientSession) CallToolTask(ctx context.Context, params *CallToolParams) (*CreateTaskResult, error) {
if params == nil || params.Task == nil {
return nil, fmt.Errorf("CallToolTask requires params.Task")
}
if params.Arguments == nil {
// Avoid sending nil over the wire.
params.Arguments = map[string]any{}
}
return handleSend[*CreateTaskResult](ctx, methodCallTool, newClientRequest(cs, params))
}

// GetTask polls task status via tasks/get.
func (cs *ClientSession) GetTask(ctx context.Context, params *GetTaskParams) (*GetTaskResult, error) {
return handleSend[*GetTaskResult](ctx, methodGetTask, newClientRequest(cs, orZero[Params](params)))
}

// ListTasks lists tasks via tasks/list.
func (cs *ClientSession) ListTasks(ctx context.Context, params *ListTasksParams) (*ListTasksResult, error) {
return handleSend[*ListTasksResult](ctx, methodListTasks, newClientRequest(cs, orZero[Params](params)))
}

// CancelTask requests cancellation via tasks/cancel.
func (cs *ClientSession) CancelTask(ctx context.Context, params *CancelTaskParams) (*CancelTaskResult, error) {
return handleSend[*CancelTaskResult](ctx, methodCancelTask, newClientRequest(cs, orZero[Params](params)))
}

// TaskResult retrieves the final result of a task via tasks/result.
//
// Currently, this SDK supports tasks/result only for tasks created from tools/call.
func (cs *ClientSession) TaskResult(ctx context.Context, params *TaskResultParams) (*CallToolResult, error) {
return handleSend[*CallToolResult](ctx, methodTaskResult, newClientRequest(cs, orZero[Params](params)))
}

func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLoggingLevelParams) error {
_, err := handleSend[*emptyResult](ctx, methodSetLevel, newClientRequest(cs, orZero[Params](params)))
return err
Expand Down Expand Up @@ -971,6 +1012,13 @@ func (c *Client) callLoggingHandler(ctx context.Context, req *LoggingMessageRequ
return nil, nil
}

func (c *Client) callTaskStatusHandler(ctx context.Context, req *TaskStatusNotificationRequest) (Result, error) {
if h := c.opts.TaskStatusHandler; h != nil {
h(ctx, req)
}
return nil, nil
}

func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, params *ProgressNotificationParams) (Result, error) {
if h := cs.client.opts.ProgressNotificationHandler; h != nil {
h(ctx, clientRequestFor(cs, params))
Expand Down
168 changes: 168 additions & 0 deletions mcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ type CallToolParams struct {
// Arguments holds the tool arguments. It can hold any value that can be
// marshaled to JSON.
Arguments any `json:"arguments,omitempty"`
// Task optionally requests task-based execution of this tool call.
//
// Note: when Task is present, the wire response is a CreateTaskResult rather
// than a CallToolResult.
Task *TaskParams `json:"task,omitempty"`
}

// CallToolParamsRaw is passed to tool handlers on the server. Its arguments
Expand All @@ -63,6 +68,8 @@ type CallToolParamsRaw struct {
// is the responsibility of the tool handler to unmarshal and validate the
// Arguments (see [AddTool]).
Arguments json.RawMessage `json:"arguments,omitempty"`
// Task optionally requests task-based execution of this tool call.
Task *TaskParams `json:"task,omitempty"`
}

// A CallToolResult is the server's response to a tool call.
Expand Down Expand Up @@ -210,13 +217,16 @@ type ClientCapabilities struct {
Sampling *SamplingCapabilities `json:"sampling,omitempty"`
// Elicitation is present if the client supports elicitation from the server.
Elicitation *ElicitationCapabilities `json:"elicitation,omitempty"`
// Tasks describes support for task-based execution.
Tasks *TasksCapabilities `json:"tasks,omitempty"`
}

// clone returns a deep copy of the ClientCapabilities.
func (c *ClientCapabilities) clone() *ClientCapabilities {
cp := *c
cp.RootsV2 = shallowClone(c.RootsV2)
cp.Sampling = shallowClone(c.Sampling)
cp.Tasks = shallowClone(c.Tasks)
if c.Elicitation != nil {
x := *c.Elicitation
x.Form = shallowClone(c.Elicitation.Form)
Expand Down Expand Up @@ -1092,8 +1102,28 @@ type Tool struct {
Title string `json:"title,omitempty"`
// Icons for the tool, if any.
Icons []Icon `json:"icons,omitempty"`
// Execution contains optional execution-related settings.
Execution *ToolExecution `json:"execution,omitempty"`
}

// ToolExecution configures execution behavior for a tool.
type ToolExecution struct {
// TaskSupport declares task support for this tool.
//
// Valid values are: "required", "optional", or "forbidden".
// See ToolTaskSupportRequired, ToolTaskSupportOptional, and ToolTaskSupportForbidden.
TaskSupport string `json:"taskSupport,omitempty"`
}

const (
// ToolTaskSupportRequired indicates the tool MUST be invoked with task augmentation.
ToolTaskSupportRequired = "required"
// ToolTaskSupportOptional indicates the tool MAY be invoked with task augmentation.
ToolTaskSupportOptional = "optional"
// ToolTaskSupportForbidden indicates the tool MUST NOT be invoked with task augmentation.
ToolTaskSupportForbidden = "forbidden"
)

// Additional properties describing a Tool to clients.
//
// NOTE: all properties in ToolAnnotations are hints. They are not
Expand Down Expand Up @@ -1314,6 +1344,8 @@ type ServerCapabilities struct {
Resources *ResourceCapabilities `json:"resources,omitempty"`
// Tools is present if the supports tools.
Tools *ToolCapabilities `json:"tools,omitempty"`
// Tasks describes support for task-based execution.
Tasks *TasksCapabilities `json:"tasks,omitempty"`
}

// clone returns a deep copy of the ServerCapabilities.
Expand All @@ -1324,12 +1356,148 @@ func (c *ServerCapabilities) clone() *ServerCapabilities {
cp.Prompts = shallowClone(c.Prompts)
cp.Resources = shallowClone(c.Resources)
cp.Tools = shallowClone(c.Tools)
cp.Tasks = shallowClone(c.Tasks)
return &cp
}

// TasksCapabilities describes support for task-based execution.
type TasksCapabilities struct {
List *TasksListCapabilities `json:"list,omitempty"`
Cancel *TasksCancelCapabilities `json:"cancel,omitempty"`
Requests *TasksRequestsCapabilities `json:"requests,omitempty"`
}

type TasksListCapabilities struct{}
type TasksCancelCapabilities struct{}

type TasksRequestsCapabilities struct {
Tools *TasksToolsRequestCapabilities `json:"tools,omitempty"`
Sampling *TasksSamplingRequestCapabilities `json:"sampling,omitempty"`
Elicitation *TasksElicitationRequestCapabilities `json:"elicitation,omitempty"`
}

type TasksToolsRequestCapabilities struct {
Call *TasksToolsCallCapabilities `json:"call,omitempty"`
}

type TasksToolsCallCapabilities struct{}

type TasksSamplingRequestCapabilities struct {
CreateMessage *TasksSamplingCreateMessageCapabilities `json:"createMessage,omitempty"`
}

type TasksSamplingCreateMessageCapabilities struct{}

type TasksElicitationRequestCapabilities struct {
Create *TasksElicitationCreateCapabilities `json:"create,omitempty"`
}

type TasksElicitationCreateCapabilities struct{}

// TaskParams is included in request parameters to request task-based execution.
type TaskParams struct {
TTL *int64 `json:"ttl,omitempty"`
}

type TaskStatus string

const (
TaskStatusWorking TaskStatus = "working"
TaskStatusInputRequired TaskStatus = "input_required"
TaskStatusCompleted TaskStatus = "completed"
TaskStatusFailed TaskStatus = "failed"
TaskStatusCancelled TaskStatus = "cancelled"
)

// Task describes the state of a task.
type Task struct {
Meta `json:"_meta,omitempty"`
TaskID string `json:"taskId"`
Status TaskStatus `json:"status"`
StatusMessage string `json:"statusMessage,omitempty"`
CreatedAt string `json:"createdAt"`
LastUpdatedAt string `json:"lastUpdatedAt"`
TTL *int64 `json:"ttl"`
PollInterval *int64 `json:"pollInterval,omitempty"`
}

// CreateTaskResult is returned for task-augmented requests.
type CreateTaskResult struct {
Meta `json:"_meta,omitempty"`
Task *Task `json:"task"`
}

func (*CreateTaskResult) isResult() {}

type GetTaskParams struct {
Meta `json:"_meta,omitempty"`
TaskID string `json:"taskId"`
}

func (*GetTaskParams) isParams() {}
func (x *GetTaskParams) GetProgressToken() any { return getProgressToken(x) }
func (x *GetTaskParams) SetProgressToken(t any) { setProgressToken(x, t) }

type GetTaskResult Task

func (*GetTaskResult) isResult() {}

type ListTasksParams struct {
Meta `json:"_meta,omitempty"`
Cursor string `json:"cursor,omitempty"`
}

func (x *ListTasksParams) isParams() {}
func (x *ListTasksParams) GetProgressToken() any { return getProgressToken(x) }
func (x *ListTasksParams) SetProgressToken(t any) { setProgressToken(x, t) }
func (x *ListTasksParams) cursorPtr() *string { return &x.Cursor }

type ListTasksResult struct {
Meta `json:"_meta,omitempty"`
Tasks []*Task `json:"tasks"`
NextCursor string `json:"nextCursor,omitempty"`
}

func (*ListTasksResult) isResult() {}
func (x *ListTasksResult) nextCursorPtr() *string { return &x.NextCursor }

type CancelTaskParams struct {
Meta `json:"_meta,omitempty"`
TaskID string `json:"taskId"`
}

func (*CancelTaskParams) isParams() {}
func (x *CancelTaskParams) GetProgressToken() any { return getProgressToken(x) }
func (x *CancelTaskParams) SetProgressToken(t any) { setProgressToken(x, t) }

type CancelTaskResult Task

func (*CancelTaskResult) isResult() {}

type TaskResultParams struct {
Meta `json:"_meta,omitempty"`
TaskID string `json:"taskId"`
}

func (*TaskResultParams) isParams() {}
func (x *TaskResultParams) GetProgressToken() any { return getProgressToken(x) }
func (x *TaskResultParams) SetProgressToken(t any) { setProgressToken(x, t) }

// TaskStatusNotificationParams is sent as notifications/tasks/status.
type TaskStatusNotificationParams Task

func (*TaskStatusNotificationParams) isParams() {}
func (x *TaskStatusNotificationParams) GetProgressToken() any { return getProgressToken(x) }
func (x *TaskStatusNotificationParams) SetProgressToken(t any) { setProgressToken(x, t) }

const (
methodCallTool = "tools/call"
methodGetTask = "tasks/get"
methodListTasks = "tasks/list"
methodCancelTask = "tasks/cancel"
methodTaskResult = "tasks/result"
notificationCancelled = "notifications/cancelled"
notificationTaskStatus = "notifications/tasks/status"
methodComplete = "completion/complete"
methodCreateMessage = "sampling/createMessage"
methodElicit = "elicitation/create"
Expand Down
6 changes: 6 additions & 0 deletions mcp/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ package mcp

type (
CallToolRequest = ServerRequest[*CallToolParamsRaw]
CancelTaskRequest = ServerRequest[*CancelTaskParams]
CompleteRequest = ServerRequest[*CompleteParams]
GetTaskRequest = ServerRequest[*GetTaskParams]
GetPromptRequest = ServerRequest[*GetPromptParams]
InitializedRequest = ServerRequest[*InitializedParams]
ListTasksRequest = ServerRequest[*ListTasksParams]
ListPromptsRequest = ServerRequest[*ListPromptsParams]
ListResourcesRequest = ServerRequest[*ListResourcesParams]
ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams]
Expand All @@ -19,6 +22,8 @@ type (
ReadResourceRequest = ServerRequest[*ReadResourceParams]
RootsListChangedRequest = ServerRequest[*RootsListChangedParams]
SubscribeRequest = ServerRequest[*SubscribeParams]
TaskStatusNotificationServerRequest = ServerRequest[*TaskStatusNotificationParams]
TaskResultRequest = ServerRequest[*TaskResultParams]
UnsubscribeRequest = ServerRequest[*UnsubscribeParams]
)

Expand All @@ -33,6 +38,7 @@ type (
PromptListChangedRequest = ClientRequest[*PromptListChangedParams]
ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams]
ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams]
TaskStatusNotificationRequest = ClientRequest[*TaskStatusNotificationParams]
ToolListChangedRequest = ClientRequest[*ToolListChangedParams]
ElicitationCompleteNotificationRequest = ClientRequest[*ElicitationCompleteParams]
)
Loading