//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDA_STREAM_REF
#define _CUDA_STREAM_REF

/*
    stream_ref synopsis
namespace cuda {
class stream_ref {
    using value_type = cudaStream_t;

    stream_ref() = default;
    stream_ref(cudaStream_t stream_) noexcept : stream(stream_) {}

    stream_ref(int) = delete;
    stream_ref(nullptr_t) = delete;

    [[nodiscard]] value_type get() const noexcept;

    void wait() const;

    [[nodiscard]] bool ready() const;

    [[nodiscard]] int priority() const;

    [[nodiscard]] friend bool operator==(stream_ref, stream_ref);
    [[nodiscard]] friend bool operator!=(stream_ref, stream_ref);

private:
  cudaStream_t stream = 0; // exposition only
};
}  // cuda
*/

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda_runtime_api.h>

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/std/__exception/cuda_error.h>
#include <cuda/std/cstddef>

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

/**
 * \brief A non-owning wrapper for a `cudaStream_t`.
 */
class stream_ref
{
protected:
  ::cudaStream_t __stream{0};

public:
  using value_type = ::cudaStream_t;

  /**
   * \brief Constructs a `stream_ref` of the "default" CUDA stream.
   *
   * For behavior of the default stream,
   * \see
   * https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
   *
   */
  _CCCL_HIDE_FROM_ABI stream_ref() = default;

  /**
   * \brief Constructs a `stream_ref` from a `cudaStream_t` handle.
   *
   * This constructor provides implicit conversion from `cudaStream_t`.
   *
   * \note: It is the callers responsibility to ensure the `stream_ref` does not
   * outlive the stream identified by the `cudaStream_t` handle.
   *
   */
  constexpr stream_ref(value_type __stream_) noexcept
      : __stream{__stream_}
  {}

  /// Disallow construction from an `int`, e.g., `0`.
  stream_ref(int) = delete;

  /// Disallow construction from `nullptr`.
  stream_ref(_CUDA_VSTD::nullptr_t) = delete;

  /**
   * \brief Compares two `stream_ref`s for equality
   *
   * \note Allows comparison with `cudaStream_t` due to implicit conversion to
   * `stream_ref`.
   *
   * \param lhs The first `stream_ref` to compare
   * \param rhs The second `stream_ref` to compare
   * \return true if equal, false if unequal
   */
  _CCCL_NODISCARD_FRIEND constexpr bool operator==(const stream_ref& __lhs, const stream_ref& __rhs) noexcept
  {
    return __lhs.__stream == __rhs.__stream;
  }

  /**
   * \brief Compares two `stream_ref`s for inequality
   *
   * \note Allows comparison with `cudaStream_t` due to implicit conversion to
   * `stream_ref`.
   *
   * \param lhs The first `stream_ref` to compare
   * \param rhs The second `stream_ref` to compare
   * \return true if unequal, false if equal
   */
  _CCCL_NODISCARD_FRIEND constexpr bool operator!=(const stream_ref& __lhs, const stream_ref& __rhs) noexcept
  {
    return __lhs.__stream != __rhs.__stream;
  }

  /// Returns the wrapped `cudaStream_t` handle.
  _CCCL_NODISCARD constexpr value_type get() const noexcept
  {
    return __stream;
  }

  /**
   * \brief Synchronizes the wrapped stream.
   *
   * \throws cuda::cuda_error if synchronization fails.
   *
   */
  void wait() const
  {
    _CCCL_TRY_CUDA_API(::cudaStreamSynchronize, "Failed to synchronize stream.", get());
  }

  /**
   * \brief Queries if all operations on the wrapped stream have completed.
   *
   * \throws cuda::cuda_error if the query fails.
   *
   * \return `true` if all operations have completed, or `false` if not.
   */
  _CCCL_NODISCARD bool ready() const
  {
    const auto __result = ::cudaStreamQuery(get());
    if (__result == ::cudaErrorNotReady)
    {
      return false;
    }
    switch (__result)
    {
      case ::cudaSuccess:
        break;
      default:
        ::cudaGetLastError(); // Clear CUDA error state
        ::cuda::__throw_cuda_error(__result, "Failed to query stream.");
    }
    return true;
  }

  /**
   * \brief Queries the priority of the wrapped stream.
   *
   * \throws cuda::cuda_error if the query fails.
   *
   * \return value representing the priority of the wrapped stream.
   */
  _CCCL_NODISCARD int priority() const
  {
    int __result = 0;
    _CCCL_TRY_CUDA_API(::cudaStreamGetPriority, "Failed to get stream priority", get(), &__result);
    return __result;
  }
};

_LIBCUDACXX_END_NAMESPACE_CUDA

#endif //_CUDA_STREAM_REF
