/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.spec;

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.server.McpNotificationHandler;
import io.modelcontextprotocol.server.McpRequestHandler;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpLoggableSession;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
import io.modelcontextprotocol.spec.MissingMcpTransportSession;
import io.modelcontextprotocol.util.Assert;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;

public class McpStreamableServerSession
implements McpLoggableSession {
    private static final Logger logger = LoggerFactory.getLogger(McpStreamableServerSession.class);
    private final ConcurrentHashMap<Object, McpStreamableServerSessionStream> requestIdToStream = new ConcurrentHashMap();
    private final String id;
    private final Duration requestTimeout;
    private final AtomicLong requestCounter = new AtomicLong(0L);
    private final Map<String, McpRequestHandler<?>> requestHandlers;
    private final Map<String, McpNotificationHandler> notificationHandlers;
    private final AtomicReference<McpSchema.ClientCapabilities> clientCapabilities = new AtomicReference();
    private final AtomicReference<McpSchema.Implementation> clientInfo = new AtomicReference();
    private final AtomicReference<McpLoggableSession> listeningStreamRef;
    private final MissingMcpTransportSession missingMcpTransportSession;
    private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO;

    public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, Duration requestTimeout, Map<String, McpRequestHandler<?>> requestHandlers, Map<String, McpNotificationHandler> notificationHandlers) {
        this.id = id;
        this.missingMcpTransportSession = new MissingMcpTransportSession(id);
        this.listeningStreamRef = new AtomicReference<MissingMcpTransportSession>(this.missingMcpTransportSession);
        this.clientCapabilities.lazySet(clientCapabilities);
        this.clientInfo.lazySet(clientInfo);
        this.requestTimeout = requestTimeout;
        this.requestHandlers = requestHandlers;
        this.notificationHandlers = notificationHandlers;
    }

    @Override
    public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) {
        Assert.notNull((Object)minLoggingLevel, "minLoggingLevel must not be null");
        this.minLoggingLevel = minLoggingLevel;
    }

    @Override
    public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) {
        return loggingLevel.level() >= this.minLoggingLevel.level();
    }

    public String getId() {
        return this.id;
    }

    private String generateRequestId() {
        return this.id + "-" + this.requestCounter.getAndIncrement();
    }

    @Override
    public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
        return Mono.defer(() -> {
            McpLoggableSession listeningStream = this.listeningStreamRef.get();
            return listeningStream.sendRequest(method, requestParams, typeRef);
        });
    }

    @Override
    public Mono<Void> sendNotification(String method, Object params) {
        return Mono.defer(() -> {
            McpLoggableSession listeningStream = this.listeningStreamRef.get();
            return listeningStream.sendNotification(method, params);
        });
    }

    public Mono<Void> delete() {
        return this.closeGracefully().then(Mono.fromRunnable(() -> {}));
    }

    public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) {
        McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport);
        this.listeningStreamRef.set(listeningStream);
        return listeningStream;
    }

    public Flux<McpSchema.JSONRPCMessage> replay(Object lastEventId) {
        return Flux.empty();
    }

    public Mono<Void> responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStreamableServerTransport transport) {
        return Mono.deferContextual(ctx -> {
            McpTransportContext transportContext = (McpTransportContext)ctx.getOrDefault((Object)"MCP_TRANSPORT_CONTEXT", (Object)McpTransportContext.EMPTY);
            McpStreamableServerSessionStream stream = new McpStreamableServerSessionStream(transport);
            McpRequestHandler<?> requestHandler = this.requestHandlers.get(jsonrpcRequest.method());
            if (requestHandler == null) {
                MethodNotFoundError error = this.getMethodNotFoundError(jsonrpcRequest.method());
                return transport.sendMessage(new McpSchema.JSONRPCResponse("2.0", jsonrpcRequest.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(-32601, error.message(), error.data())));
            }
            return requestHandler.handle(new McpAsyncServerExchange(this.id, stream, this.clientCapabilities.get(), this.clientInfo.get(), transportContext), jsonrpcRequest.params()).map(result -> new McpSchema.JSONRPCResponse("2.0", jsonrpcRequest.id(), result, null)).onErrorResume(e -> {
                McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse("2.0", jsonrpcRequest.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(-32603, e.getMessage(), null));
                return Mono.just((Object)errorResponse);
            }).flatMap(transport::sendMessage).then(transport.closeGracefully());
        });
    }

    public Mono<Void> accept(McpSchema.JSONRPCNotification notification) {
        return Mono.deferContextual(ctx -> {
            McpTransportContext transportContext = (McpTransportContext)ctx.getOrDefault((Object)"MCP_TRANSPORT_CONTEXT", (Object)McpTransportContext.EMPTY);
            McpNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method());
            if (notificationHandler == null) {
                logger.warn("No handler registered for notification method: {}", (Object)notification);
                return Mono.empty();
            }
            McpLoggableSession listeningStream = this.listeningStreamRef.get();
            return notificationHandler.handle(new McpAsyncServerExchange(this.id, listeningStream, this.clientCapabilities.get(), this.clientInfo.get(), transportContext), notification.params());
        });
    }

    public Mono<Void> accept(McpSchema.JSONRPCResponse response) {
        return Mono.defer(() -> {
            logger.debug("Received response: {}", (Object)response);
            if (response.id() != null) {
                McpStreamableServerSessionStream stream = this.requestIdToStream.get(response.id());
                if (stream == null) {
                    return Mono.error((Throwable)McpError.builder(-32603).message("Unexpected response for unknown id " + String.valueOf(response.id())).build());
                }
                MonoSink<McpSchema.JSONRPCResponse> sink = stream.pendingResponses.remove(response.id());
                if (sink == null) {
                    return Mono.error((Throwable)McpError.builder(-32603).message("Unexpected response for unknown id " + String.valueOf(response.id())).build());
                }
                sink.success((Object)response);
            } else {
                logger.error("Discarded MCP request response without session id. This is an indication of a bug in the request sender code that can lead to memory leaks as pending requests will never be completed.");
            }
            return Mono.empty();
        });
    }

    private MethodNotFoundError getMethodNotFoundError(String method) {
        return new MethodNotFoundError(method, "Method not found: " + method, null);
    }

    @Override
    public Mono<Void> closeGracefully() {
        return Mono.defer(() -> {
            McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(this.missingMcpTransportSession);
            return listeningStream.closeGracefully();
        });
    }

    @Override
    public void close() {
        McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(this.missingMcpTransportSession);
        if (listeningStream != null) {
            listeningStream.close();
        }
    }

    public final class McpStreamableServerSessionStream
    implements McpLoggableSession {
        private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap();
        private final McpStreamableServerTransport transport;
        private final String transportId;
        private final Supplier<String> uuidGenerator;

        public McpStreamableServerSessionStream(McpStreamableServerTransport transport) {
            this.transport = transport;
            this.transportId = UUID.randomUUID().toString();
            this.uuidGenerator = () -> this.transportId + "_" + String.valueOf(UUID.randomUUID());
        }

        @Override
        public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) {
            Assert.notNull((Object)minLoggingLevel, "minLoggingLevel must not be null");
            McpStreamableServerSession.this.setMinLoggingLevel(minLoggingLevel);
        }

        @Override
        public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) {
            return McpStreamableServerSession.this.isNotificationForLevelAllowed(loggingLevel);
        }

        @Override
        public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
            String requestId = McpStreamableServerSession.this.generateRequestId();
            McpStreamableServerSession.this.requestIdToStream.put(requestId, this);
            return Mono.create(sink -> {
                this.pendingResponses.put(requestId, (MonoSink<McpSchema.JSONRPCResponse>)sink);
                McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest("2.0", method, requestId, requestParams);
                String messageId = this.uuidGenerator.get();
                this.transport.sendMessage(jsonrpcRequest, messageId).subscribe(v -> {}, arg_0 -> ((MonoSink)sink).error(arg_0));
            }).timeout(McpStreamableServerSession.this.requestTimeout).doOnError(e -> {
                this.pendingResponses.remove(requestId);
                McpStreamableServerSession.this.requestIdToStream.remove(requestId);
            }).handle((jsonRpcResponse, sink) -> {
                if (jsonRpcResponse.error() != null) {
                    sink.error((Throwable)new McpError(jsonRpcResponse.error()));
                } else if (typeRef.getType().equals(Void.class)) {
                    sink.complete();
                } else {
                    sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef));
                }
            });
        }

        @Override
        public Mono<Void> sendNotification(String method, Object params) {
            McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification("2.0", method, params);
            String messageId = this.uuidGenerator.get();
            return this.transport.sendMessage(jsonrpcNotification, messageId);
        }

        @Override
        public Mono<Void> closeGracefully() {
            return Mono.defer(() -> {
                this.pendingResponses.values().forEach(s -> s.error((Throwable)new RuntimeException("Stream closed")));
                this.pendingResponses.clear();
                McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, McpStreamableServerSession.this.missingMcpTransportSession);
                McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals);
                return this.transport.closeGracefully();
            });
        }

        @Override
        public void close() {
            this.pendingResponses.values().forEach(s -> s.error((Throwable)new RuntimeException("Stream closed")));
            this.pendingResponses.clear();
            McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, McpStreamableServerSession.this.missingMcpTransportSession);
            McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals);
            this.transport.close();
        }
    }

    record MethodNotFoundError(String method, String message, Object data) {
    }

    public record McpStreamableServerSessionInit(McpStreamableServerSession session, Mono<McpSchema.InitializeResult> initResult) {
    }

    public static interface Factory {
        public McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest var1);
    }

    public static interface InitRequestHandler {
        public Mono<McpSchema.InitializeResult> handle(McpSchema.InitializeRequest var1);
    }
}

