Newer
Older
SafetyAuxiliary_AR / sdk / native / jni / include / opencv2 / gapi / plaidml / gplaidmlkernel.hpp
pengxianhong on 12 Jun 2024 3 KB 集成OpenCV
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2019 Intel Corporation
//


#ifndef OPENCV_GAPI_GPLAIDMLKERNEL_HPP
#define OPENCV_GAPI_GPLAIDMLKERNEL_HPP

#include <opencv2/gapi/gkernel.hpp>
#include <opencv2/gapi/garg.hpp>

// Forward declaration of the PlaidML EDSL tensor class. The includes visible
// in this header are OpenCV-only, so the type is merely named here; its full
// definition has to come from elsewhere.
namespace plaidml {
namespace edsl {

class Tensor;

} // namespace edsl
} // namespace plaidml

namespace cv
{
namespace gapi
{
namespace plaidml
{

// Returns the cv::gapi::GBackend object declared for this plaidml namespace.
// Only the (exported) declaration is present in this header; the definition
// is provided elsewhere. GPlaidMLKernelImpl::backend() below forwards to it.
GAPI_EXPORTS cv::gapi::GBackend backend();

} // namespace plaidml
} // namespace gapi

struct GPlaidMLContext
{
    // Generic accessor API
    template<typename T>
    const T& inArg(int input) { return m_args.at(input).get<T>(); }

    // Syntax sugar
    const plaidml::edsl::Tensor& inTensor(int input)
    {
        return inArg<plaidml::edsl::Tensor>(input);
    }

    plaidml::edsl::Tensor& outTensor(int output)
    {
        return *(m_results.at(output).get<plaidml::edsl::Tensor*>());
    }

    std::vector<GArg> m_args;
    std::unordered_map<std::size_t, GArg> m_results;
};

// Type-erased PlaidML kernel: wraps a callable that takes a GPlaidMLContext.
class GAPI_EXPORTS GPlaidMLKernel
{
public:
    // Signature every kernel body has to match.
    using F = std::function<void(GPlaidMLContext &)>;

    // A default-constructed kernel holds no callable.
    GPlaidMLKernel() = default;

    // Stores a copy of the given callable.
    explicit GPlaidMLKernel(const F& f)
        : m_f(f)
    {
    }

    // Invokes the stored callable on `ctx`.
    // Asserts first that a callable is actually present.
    void apply(GPlaidMLContext &ctx) const
    {
        GAPI_Assert(m_f);
        m_f(ctx);
    }

protected:
    F m_f; // the wrapped kernel body (may be empty)
};


namespace detail
{

// Fetches kernel input number `idx` from the context, selected by the
// kernel's declared input type.
//
// Generic case: the input is read via GPlaidMLContext::inArg<T>() and handed
// back by value (i.e. a copy of the stored argument).
template<class T> struct plaidml_get_in
{
    static T get(GPlaidMLContext &ctx, int idx)
    {
        const T& stored = ctx.inArg<T>(idx);
        return stored;
    }
};

// cv::GMat inputs are exposed to the kernel as PlaidML tensors, by reference.
template<> struct plaidml_get_in<cv::GMat>
{
    static const plaidml::edsl::Tensor& get(GPlaidMLContext& ctx, int idx)
    {
        const plaidml::edsl::Tensor& tensor = ctx.inTensor(idx);
        return tensor;
    }
};

// Fetches kernel output number `idx` from the context, selected by the
// kernel's declared output type. The primary template is declared only;
// cv::GMat is the one output type given a definition here.
template<class T> struct plaidml_get_out;

// cv::GMat outputs are exposed to the kernel as mutable PlaidML tensors.
template<> struct plaidml_get_out<cv::GMat>
{
    static plaidml::edsl::Tensor& get(GPlaidMLContext& ctx, int idx)
    {
        plaidml::edsl::Tensor& tensor = ctx.outTensor(idx);
        return tensor;
    }
};

template<typename, typename, typename>
struct PlaidMLCallHelper;

template<typename Impl, typename... Ins, typename... Outs>
struct PlaidMLCallHelper<Impl, std::tuple<Ins...>, std::tuple<Outs...> >
{
    template<int... IIs, int... OIs>
    static void call_impl(GPlaidMLContext &ctx, detail::Seq<IIs...>, detail::Seq<OIs...>)
    {
        Impl::run(plaidml_get_in<Ins>::get(ctx, IIs)..., plaidml_get_out<Outs>::get(ctx, OIs)...);
    }

    static void call(GPlaidMLContext& ctx)
    {
        call_impl(ctx,
                  typename detail::MkSeq<sizeof...(Ins)>::type(),
                  typename detail::MkSeq<sizeof...(Outs)>::type());
    }
};

} // namespace detail

// Base class for a PlaidML kernel implementation.
//   Impl - the implementing type; its static run() is what gets invoked
//          through PlaidMLCallHelper.
//   K    - the kernel API type; its InArgs / OutArgs describe the signature.
template<class Impl, class K>
class GPlaidMLKernelImpl
    : public cv::detail::PlaidMLCallHelper<Impl, typename K::InArgs, typename K::OutArgs>
    , public cv::detail::KernelTag
{
    // Shorthand for the call-helper base.
    using P = detail::PlaidMLCallHelper<Impl, typename K::InArgs, typename K::OutArgs>;

public:
    using API = K;

    // Backend this kernel belongs to.
    static cv::gapi::GBackend backend()
    {
        return cv::gapi::plaidml::backend();
    }

    // Wraps the helper's static call() into a GPlaidMLKernel object.
    static cv::GPlaidMLKernel kernel()
    {
        return GPlaidMLKernel(&P::call);
    }
};

// Opens the definition of a PlaidML kernel implementation: expands to a struct
// header `struct Name : public cv::GPlaidMLKernelImpl<Name, API>`. The user
// supplies the struct body after the macro; PlaidMLCallHelper invokes a static
// run() on that struct.
#define GAPI_PLAIDML_KERNEL(Name, API) struct Name: public cv::GPlaidMLKernelImpl<Name, API>

} // namespace cv

#endif // OPENCV_GAPI_GPLAIDMLKERNEL_HPP