/*
 * Copyright (C) 2021 Apple Inc. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGE.
 */

#include "config.h"
#include "WebBroadcastChannelRegistry.h"

#include "NetworkBroadcastChannelRegistryMessages.h"
#include "NetworkProcessConnection.h"
#include "WebProcess.h"
#include <WebCore/BroadcastChannel.h>
#include <WebCore/MessageWithMessagePorts.h>
#include <wtf/CallbackAggregator.h>

namespace WebKit {

// Returns the IPC connection obtained from the WebProcess singleton's
// ensureNetworkProcessConnection(). Every message in this file is sent through it.
static inline IPC::Connection& networkProcessConnection()
{
    auto& webProcess = WebProcess::singleton();
    auto& connectionToNetworkProcess = webProcess.ensureNetworkProcessConnection();
    return connectionToNetworkProcess.connection();
}

// Records |identifier| as a channel for the (origin, name) pair in this process.
// A RegisterChannel message is sent to the network process only when |identifier|
// is the first one stored for that pair; later registrations for the same pair
// are tracked locally without sending anything.
void WebBroadcastChannelRegistry::registerChannel(const WebCore::ClientOrigin& origin, const String& name, WebCore::BroadcastChannelIdentifier identifier)
{
    // Fetch the per-origin table, creating an empty one if this origin is new.
    auto originAddResult = m_channelsPerOrigin.ensure(origin, [] {
        return HashMap<String, Vector<WebCore::BroadcastChannelIdentifier>> { };
    });
    auto& namesForOrigin = originAddResult.iterator->value;

    // Fetch the identifier list for this name, creating an empty one if needed.
    auto nameAddResult = namesForOrigin.ensure(name, [] {
        return Vector<WebCore::BroadcastChannelIdentifier> { };
    });
    auto& identifiers = nameAddResult.iterator->value;
    identifiers.append(identifier);

    // More than one identifier means this (origin, name) pair was already announced.
    if (identifiers.size() != 1)
        return;

    networkProcessConnection().send(Messages::NetworkBroadcastChannelRegistry::RegisterChannel { origin, name }, 0);
}

// Removes |identifier| from the channels tracked for the (origin, name) pair.
// Does nothing if the origin, the name, or the identifier is not currently tracked.
// When the last identifier for the pair goes away, the name entry is dropped and an
// UnregisterChannel message is sent to the network process; the origin entry is
// dropped as well once it has no names left.
void WebBroadcastChannelRegistry::unregisterChannel(const WebCore::ClientOrigin& origin, const String& name, WebCore::BroadcastChannelIdentifier identifier)
{
    auto originEntry = m_channelsPerOrigin.find(origin);
    if (originEntry == m_channelsPerOrigin.end())
        return;

    auto& namesForOrigin = originEntry->value;
    auto nameEntry = namesForOrigin.find(name);
    if (nameEntry == namesForOrigin.end())
        return;

    auto& identifiers = nameEntry->value;
    bool didRemoveIdentifier = identifiers.removeFirst(identifier);

    // Nothing more to do if the identifier was unknown, or if other channels
    // in this process still use the same (origin, name) pair.
    if (!didRemoveIdentifier || !identifiers.isEmpty())
        return;

    // That was the last local channel for this name: forget the name and notify
    // the network process before potentially dropping the origin entry.
    namesForOrigin.remove(nameEntry);
    networkProcessConnection().send(Messages::NetworkBroadcastChannelRegistry::UnregisterChannel { origin, name }, 0);

    if (!namesForOrigin.isEmpty())
        return;

    m_channelsPerOrigin.remove(originEntry);
}

// Posts |message| from the local channel |source| to the (origin, name) pair.
// The message is first handed to postMessageLocally() for channels tracked in this
// process, then sent to the network process as a PostMessage with an async reply.
// |completionHandler| is wrapped in a CallbackAggregator that is shared by the local
// dispatches and by the async reply handler.
void WebBroadcastChannelRegistry::postMessage(const WebCore::ClientOrigin& origin, const String& name, WebCore::BroadcastChannelIdentifier source, Ref<WebCore::SerializedScriptValue>&& message, CompletionHandler<void()>&& completionHandler)
{
    auto aggregator = CallbackAggregator::create(WTFMove(completionHandler));

    // Local delivery gets its own reference; the original is moved into the IPC payload below.
    postMessageLocally(origin, name, source, message.copyRef(), aggregator.copyRef());

    // The second member is left empty (value-initialized).
    WebCore::MessageWithMessagePorts messageWithPorts { WTFMove(message), { } };

    // The reply handler has no body; it only holds a reference to the aggregator.
    networkProcessConnection().sendWithAsyncReply(Messages::NetworkBroadcastChannelRegistry::PostMessage { origin, name, messageWithPorts }, [aggregator] { }, 0);
}

// Dispatches |message| to every channel tracked in this process for the
// (origin, name) pair, skipping the channel equal to |sourceInProcess| when one is
// given. Returns without dispatching if the origin or the name is not tracked.
// Each dispatch is given a completion callback that holds a reference to
// |callbackAggregator|.
void WebBroadcastChannelRegistry::postMessageLocally(const WebCore::ClientOrigin& origin, const String& name, std::optional<WebCore::BroadcastChannelIdentifier> sourceInProcess, Ref<WebCore::SerializedScriptValue>&& message, Ref<WTF::CallbackAggregator>&& callbackAggregator)
{
    auto originEntry = m_channelsPerOrigin.find(origin);
    if (originEntry == m_channelsPerOrigin.end())
        return;

    auto& namesForOrigin = originEntry->value;
    auto nameEntry = namesForOrigin.find(name);
    if (nameEntry == namesForOrigin.end())
        return;

    // Deliberately a copy, not a reference: the loop walks a snapshot of the identifiers.
    // NOTE(review): presumably this guards against the registry changing during dispatch — confirm.
    auto identifiersSnapshot = nameEntry->value;
    for (auto& targetIdentifier : identifiersSnapshot) {
        bool isSender = targetIdentifier == sourceInProcess;
        if (isSender)
            continue;

        WebCore::BroadcastChannel::dispatchMessageTo(targetIdentifier, message.copyRef(), [callbackAggregator] { });
    }
}

// Delivers |message| to the channels tracked in this process for the (origin, name)
// pair. No in-process source is passed (std::nullopt), so no local channel is skipped.
// |completionHandler| is wrapped in a CallbackAggregator that is handed to
// postMessageLocally().
void WebBroadcastChannelRegistry::postMessageToRemote(const WebCore::ClientOrigin& origin, const String& name, WebCore::MessageWithMessagePorts&& message, CompletionHandler<void()>&& completionHandler)
{
    auto aggregator = CallbackAggregator::create(WTFMove(completionHandler));

    // NOTE(review): message.message is dereferenced without a null check — confirm
    // that callers always provide a non-null serialized value.
    auto& serializedMessage = *message.message;
    postMessageLocally(origin, name, std::nullopt, serializedMessage, aggregator.copyRef());
}

void WebBroadcastChannelRegistry::networkProcessCrashed()
{
    for (auto& [origin, channelsForOrigin] : m_channelsPerOrigin) {
        for (auto& name : channelsForOrigin.keys())
            networkProcessConnection().send(Messages::NetworkBroadcastChannelRegistry::RegisterChannel { origin, name }, 0);
    }
}

} // namespace WebKit
