mirror of
https://github.com/blacktop/ipsw.git
synced 2026-05-08 12:22:26 +00:00
fix: ACP model selection diagnostics
This commit is contained in:
+115
-42
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -17,6 +18,8 @@ import (
|
||||
"github.com/blacktop/ipsw/internal/ai/utils"
|
||||
)
|
||||
|
||||
const maxACPStderrBytes = 4096
|
||||
|
||||
type Config struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model"`
|
||||
@@ -36,6 +39,8 @@ type ACP struct {
|
||||
models map[string]string
|
||||
}
|
||||
|
||||
var discardACPLogger = slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
|
||||
func New(ctx context.Context, conf *Config) (*ACP, error) {
|
||||
if conf == nil {
|
||||
return nil, fmt.Errorf("acp: config is nil")
|
||||
@@ -96,6 +101,7 @@ func (c *ACP) fetchModelsViaACP() (map[string]string, error) {
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("acp: failed to start agent command '%s': %w", c.conf.Command, err)
|
||||
}
|
||||
stderrCapture := copyACPStderr(stderr, c.conf.Verbose)
|
||||
defer func() {
|
||||
_ = stdin.Close()
|
||||
_ = stdout.Close()
|
||||
@@ -106,27 +112,19 @@ func (c *ACP) fetchModelsViaACP() (map[string]string, error) {
|
||||
_ = cmd.Wait()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if c.conf.Verbose {
|
||||
_, _ = io.Copy(os.Stderr, stderr)
|
||||
} else {
|
||||
_, _ = io.Copy(io.Discard, stderr)
|
||||
}
|
||||
}()
|
||||
|
||||
cwd, _, conn, err := newClientConnection(stdin, stdout)
|
||||
clientConn, err := newClientConnection(stdin, stdout, c.conf.Verbose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := initializeClient(ctx, conn); err != nil {
|
||||
return nil, err
|
||||
if err := initializeClient(ctx, clientConn.conn); err != nil {
|
||||
return nil, stderrCapture.wrap(err)
|
||||
}
|
||||
sess, err := conn.NewSession(ctx, acp.NewSessionRequest{
|
||||
Cwd: cwd,
|
||||
sess, err := clientConn.conn.NewSession(ctx, acp.NewSessionRequest{
|
||||
Cwd: clientConn.cwd,
|
||||
McpServers: []acp.McpServer{},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acp: newSession failed: %w", err)
|
||||
return nil, stderrCapture.wrap(fmt.Errorf("acp: newSession failed: %w", err))
|
||||
}
|
||||
|
||||
models := make(map[string]string)
|
||||
@@ -159,10 +157,13 @@ func (c *ACP) SetModel(model string) error {
|
||||
if c.models == nil {
|
||||
c.models = make(map[string]string)
|
||||
}
|
||||
// ACP agents/adapters may accept arbitrary model IDs without a prior list.
|
||||
if _, ok := c.models[model]; !ok {
|
||||
c.models[model] = model
|
||||
if modelID := strings.TrimSpace(c.models[model]); modelID != "" {
|
||||
c.models[modelID] = modelID
|
||||
c.conf.Model = modelID
|
||||
return nil
|
||||
}
|
||||
// ACP agents/adapters may accept arbitrary model IDs without a prior list.
|
||||
c.models[model] = model
|
||||
c.conf.Model = model
|
||||
return nil
|
||||
}
|
||||
@@ -261,14 +262,92 @@ func (c *collectingClient) WaitForTerminalExit(ctx context.Context, _ acp.WaitFo
|
||||
return acp.WaitForTerminalExitResponse{}, fmt.Errorf("acp client: terminal not supported")
|
||||
}
|
||||
|
||||
func newClientConnection(stdin io.Writer, stdout io.Reader) (string, *collectingClient, *acp.ClientSideConnection, error) {
|
||||
type acpStderrCapture struct {
|
||||
tail *lockedTailBuffer
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
type lockedTailBuffer struct {
|
||||
mu sync.Mutex
|
||||
limit int
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func newLockedTailBuffer(limit int) *lockedTailBuffer {
|
||||
return &lockedTailBuffer{limit: limit}
|
||||
}
|
||||
|
||||
func (b *lockedTailBuffer) Write(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.buf = append(b.buf, p...)
|
||||
if len(b.buf) > b.limit {
|
||||
b.buf = b.buf[len(b.buf)-b.limit:]
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (b *lockedTailBuffer) String() string {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return strings.TrimSpace(string(b.buf))
|
||||
}
|
||||
|
||||
func copyACPStderr(stderr io.Reader, verbose bool) *acpStderrCapture {
|
||||
tail := newLockedTailBuffer(maxACPStderrBytes)
|
||||
capture := &acpStderrCapture{
|
||||
tail: tail,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
dst := io.Writer(tail)
|
||||
if verbose {
|
||||
dst = io.MultiWriter(os.Stderr, tail)
|
||||
}
|
||||
go func() {
|
||||
defer close(capture.done)
|
||||
_, _ = io.Copy(dst, stderr)
|
||||
}()
|
||||
return capture
|
||||
}
|
||||
|
||||
func (c *acpStderrCapture) String() string {
|
||||
if c == nil || c.tail == nil {
|
||||
return ""
|
||||
}
|
||||
select {
|
||||
case <-c.done:
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
return c.tail.String()
|
||||
}
|
||||
|
||||
func (c *acpStderrCapture) wrap(err error) error {
|
||||
if c == nil {
|
||||
return err
|
||||
}
|
||||
if stderr := c.String(); stderr != "" {
|
||||
return fmt.Errorf("%w\nagent stderr:\n%s", err, stderr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type clientConnection struct {
|
||||
cwd string
|
||||
client *collectingClient
|
||||
conn *acp.ClientSideConnection
|
||||
}
|
||||
|
||||
func newClientConnection(stdin io.Writer, stdout io.Reader, verbose bool) (*clientConnection, error) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", nil, nil, fmt.Errorf("acp: failed to get current working directory: %w", err)
|
||||
return nil, fmt.Errorf("acp: failed to get current working directory: %w", err)
|
||||
}
|
||||
client := &collectingClient{cwd: cwd}
|
||||
conn := acp.NewClientSideConnection(client, stdin, stdout)
|
||||
return cwd, client, conn, nil
|
||||
if !verbose {
|
||||
conn.SetLogger(discardACPLogger)
|
||||
}
|
||||
return &clientConnection{cwd: cwd, client: client, conn: conn}, nil
|
||||
}
|
||||
|
||||
func initializeClient(ctx context.Context, conn *acp.ClientSideConnection) error {
|
||||
@@ -345,10 +424,7 @@ func readTextWindow(content string, line, limit *int) string {
|
||||
lines := strings.Split(content, "\n")
|
||||
start := 0
|
||||
if line != nil && *line > 1 {
|
||||
start = *line - 1
|
||||
if start > len(lines) {
|
||||
start = len(lines)
|
||||
}
|
||||
start = min(*line-1, len(lines))
|
||||
}
|
||||
end := len(lines)
|
||||
if limit != nil {
|
||||
@@ -400,6 +476,7 @@ func (c *ACP) Chat() (string, error) {
|
||||
if err := cmd.Start(); err != nil {
|
||||
return "", fmt.Errorf("acp: failed to start agent command '%s': %w", c.conf.Command, err)
|
||||
}
|
||||
stderrCapture := copyACPStderr(stderr, c.conf.Verbose)
|
||||
defer func() {
|
||||
_ = stdin.Close()
|
||||
_ = stdout.Close()
|
||||
@@ -410,48 +487,44 @@ func (c *ACP) Chat() (string, error) {
|
||||
_ = cmd.Wait()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if c.conf.Verbose {
|
||||
_, _ = io.Copy(os.Stderr, stderr)
|
||||
} else {
|
||||
_, _ = io.Copy(io.Discard, stderr)
|
||||
}
|
||||
}()
|
||||
|
||||
cwd, client, conn, err := newClientConnection(stdin, stdout)
|
||||
clientConn, err := newClientConnection(stdin, stdout, c.conf.Verbose)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := initializeClient(ctx, conn); err != nil {
|
||||
return "", err
|
||||
if err := initializeClient(ctx, clientConn.conn); err != nil {
|
||||
return "", stderrCapture.wrap(err)
|
||||
}
|
||||
|
||||
sess, err := conn.NewSession(ctx, acp.NewSessionRequest{
|
||||
Cwd: cwd,
|
||||
sess, err := clientConn.conn.NewSession(ctx, acp.NewSessionRequest{
|
||||
Cwd: clientConn.cwd,
|
||||
McpServers: []acp.McpServer{},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("acp: newSession failed: %w", err)
|
||||
return "", stderrCapture.wrap(fmt.Errorf("acp: newSession failed: %w", err))
|
||||
}
|
||||
|
||||
// Best-effort: attempt to set model if the agent supports it (UNSTABLE ACP API).
|
||||
if m := strings.TrimSpace(c.conf.Model); m != "" && m != "default" {
|
||||
if _, setErr := conn.UnstableSetSessionModel(ctx, acp.UnstableSetSessionModelRequest{SessionId: sess.SessionId, ModelId: acp.UnstableModelId(m)}); setErr != nil {
|
||||
setModelReq := acp.UnstableSetSessionModelRequest{
|
||||
SessionId: sess.SessionId,
|
||||
ModelId: acp.UnstableModelId(m),
|
||||
}
|
||||
if _, setErr := clientConn.conn.UnstableSetSessionModel(ctx, setModelReq); setErr != nil {
|
||||
if c.conf.Verbose {
|
||||
log.Debugf("acp: UnstableSetSessionModel failed (ignored): %v", setErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
promptResp, err := conn.Prompt(ctx, acp.PromptRequest{
|
||||
promptResp, err := clientConn.conn.Prompt(ctx, acp.PromptRequest{
|
||||
SessionId: sess.SessionId,
|
||||
Prompt: []acp.ContentBlock{acp.TextBlock(c.conf.Prompt)},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("acp: prompt failed: %w", err)
|
||||
return "", stderrCapture.wrap(fmt.Errorf("acp: prompt failed: %w", err))
|
||||
}
|
||||
|
||||
resp := strings.TrimSpace(client.String())
|
||||
resp := strings.TrimSpace(clientConn.client.String())
|
||||
if promptResp.StopReason != "" && promptResp.StopReason != acp.StopReasonEndTurn {
|
||||
return "", promptStopError(promptResp.StopReason, resp)
|
||||
}
|
||||
|
||||
@@ -74,6 +74,90 @@ func TestACPChatReportsNonEndTurnStopReason(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestACPChatIncludesAgentStderrOnInitializeFailure(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client := newTestACP(t, ctx, "stderr-exit", "")
|
||||
_, err := client.Chat()
|
||||
if err == nil {
|
||||
t.Fatal("Chat() error = nil, want initialize failure")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "agent stderr:") {
|
||||
t.Fatalf("Chat() error = %q, want agent stderr section", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "test acp adapter failed before initialize") {
|
||||
t.Fatalf("Chat() error = %q, want adapter stderr", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestACPSetModelResolvesDisplayNameToModelID(t *testing.T) {
|
||||
client, err := New(context.Background(), &Config{Command: "test-agent"})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
if _, err := client.SetModels(map[string]string{"Default (recommended)": "default"}); err != nil {
|
||||
t.Fatalf("SetModels() error = %v", err)
|
||||
}
|
||||
|
||||
if err := client.SetModel("Default (recommended)"); err != nil {
|
||||
t.Fatalf("SetModel() error = %v", err)
|
||||
}
|
||||
if client.conf.Model != "default" {
|
||||
t.Fatalf("conf.Model = %q, want default", client.conf.Model)
|
||||
}
|
||||
if got := client.models["default"]; got != "default" {
|
||||
t.Fatalf("models[default] = %q, want default", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestACPSetModelAllowsArbitraryModelID(t *testing.T) {
|
||||
client, err := New(context.Background(), &Config{Command: "test-agent"})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
if err := client.SetModel("custom-model"); err != nil {
|
||||
t.Fatalf("SetModel() error = %v", err)
|
||||
}
|
||||
if client.conf.Model != "custom-model" {
|
||||
t.Fatalf("conf.Model = %q, want custom-model", client.conf.Model)
|
||||
}
|
||||
if got := client.models["custom-model"]; got != "custom-model" {
|
||||
t.Fatalf("models[custom-model] = %q, want custom-model", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockedTailBufferKeepsBoundedStderrTail(t *testing.T) {
|
||||
buf := newLockedTailBuffer(5)
|
||||
if _, err := buf.Write([]byte("hello")); err != nil {
|
||||
t.Fatalf("Write() error = %v", err)
|
||||
}
|
||||
if _, err := buf.Write([]byte(" world")); err != nil {
|
||||
t.Fatalf("Write() error = %v", err)
|
||||
}
|
||||
if got := buf.String(); got != "world" {
|
||||
t.Fatalf("String() = %q, want world", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestACPStderrCaptureWaitsForCopyToFinish(t *testing.T) {
|
||||
buf := newLockedTailBuffer(32)
|
||||
capture := &acpStderrCapture{
|
||||
tail: buf,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
_, _ = buf.Write([]byte("delayed stderr"))
|
||||
close(capture.done)
|
||||
}()
|
||||
|
||||
if got := capture.String(); got != "delayed stderr" {
|
||||
t.Fatalf("String() = %q, want delayed stderr", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectingClientReadTextFileRestrictsCWD(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
inside := filepath.Join(dir, "inside.txt")
|
||||
@@ -181,6 +265,10 @@ func TestACPAgentHelperProcess(t *testing.T) {
|
||||
if mode == "" {
|
||||
return
|
||||
}
|
||||
if mode == "stderr-exit" {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "test acp adapter failed before initialize")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
agent := &testACPAgent{
|
||||
mode: mode,
|
||||
|
||||
Reference in New Issue
Block a user