/*
 * Copyright (c) 2013-2018 Intel, Inc. All rights reserved
 * Copyright (c) 2017      Los Alamos National Security, LLC. All rights
 *                         reserved.
 * Copyright (c) 2019      Triad National Security, LLC. All rights
 *                         reserved.
 *
 * $COPYRIGHT$
 *
 * Additional copyrights may follow
 *
 * $HEADER$
 */

#ifndef MTL_OFI_H_HAS_BEEN_INCLUDED
#define MTL_OFI_H_HAS_BEEN_INCLUDED

#include "ompi/mca/mtl/mtl.h"
#include "ompi/mca/mtl/base/base.h"
#include "opal/datatype/opal_convertor.h"
#include "opal/util/show_help.h"

#include <rdma/fabric.h>
#include <rdma/fi_cm.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_endpoint.h>
#include <rdma/fi_errno.h>
#include <rdma/fi_tagged.h>

#include "ompi_config.h"
#include "ompi/proc/proc.h"
#include "ompi/mca/mtl/mtl.h"
#include "opal/class/opal_list.h"
#include "ompi/communicator/communicator.h"
#include "opal/datatype/opal_convertor.h"
#include "ompi/mca/mtl/base/base.h"
#include "ompi/mca/mtl/base/mtl_base_datatype.h"
#include "ompi/message/message.h"

#include "mtl_ofi_types.h"
#include "mtl_ofi_request.h"
#include "mtl_ofi_endpoint.h"
#include "mtl_ofi_compat.h"


BEGIN_C_DECLS

/* Module-wide OFI MTL state. The inline functions in this header read its
 * endpoint (ep), completion queue (cq), tag-bit and size-limit fields. */
extern mca_mtl_ofi_module_t ompi_mtl_ofi;
/* MTL base framework object; its framework_output stream receives the
 * verbose diagnostics emitted by this header. */
extern mca_base_framework_t ompi_mtl_base_framework;

/* Declared here; the definition is not part of this header. */
extern int ompi_mtl_ofi_del_procs(struct mca_mtl_base_module_t *mtl,
                                  size_t nprocs,
                                  struct ompi_proc_t **procs);

/* NOTE(review): presumably an out-of-line wrapper around
 * ompi_mtl_ofi_progress() - confirm against its definition. */
int ompi_mtl_ofi_progress_no_inline(void);

__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_progress(void)
{
    ssize_t ret;
    int count = 0, i, events_read;
    struct fi_cq_err_entry error = { 0 };
    ompi_mtl_ofi_request_t *ofi_req = NULL;
    struct fi_cq_tagged_entry wc[ompi_mtl_ofi.ofi_progress_event_count];

    /**
     * Read the work completions from the CQ.
     * From the completion's op_context, we get the associated OFI request.
     * Call the request's callback.
     */
    while (true) {
        ret = fi_cq_read(ompi_mtl_ofi.cq, (void *)&wc, ompi_mtl_ofi.ofi_progress_event_count);
        if (ret > 0) {
            count+= ret;
            events_read = ret;
            for (i = 0; i < events_read; i++) {
                if (NULL != wc[i].op_context) {
                    ofi_req = TO_OFI_REQ(wc[i].op_context);
                    assert(ofi_req);
                    ret = ofi_req->event_callback(&wc[i], ofi_req);
                    if (OMPI_SUCCESS != ret) {
                        opal_output(0, "%s:%d: Error returned by request event callback: %zd.\n"
                                       "*** The Open MPI OFI MTL is aborting the MPI job (via exit(3)).\n",
                                       __FILE__, __LINE__, ret);
                        fflush(stderr);
                        exit(1);
                    }
                }
            }
        } else if (OPAL_UNLIKELY(ret == -FI_EAVAIL)) {
            /**
             * An error occured and is being reported via the CQ.
             * Read the error and forward it to the upper layer.
             */
            ret = fi_cq_readerr(ompi_mtl_ofi.cq,
                                &error,
                                0);
            if (0 > ret) {
                opal_output(0, "%s:%d: Error returned from fi_cq_readerr: %s(%zd).\n"
                               "*** The Open MPI OFI MTL is aborting the MPI job (via exit(3)).\n",
                               __FILE__, __LINE__, fi_strerror(-ret), ret);
                fflush(stderr);
                exit(1);
            }

            assert(error.op_context);
            ofi_req = TO_OFI_REQ(error.op_context);
            assert(ofi_req);
            ret = ofi_req->error_callback(&error, ofi_req);
            if (OMPI_SUCCESS != ret) {
                    opal_output(0, "%s:%d: Error returned by request error callback: %zd.\n"
                                   "*** The Open MPI OFI MTL is aborting the MPI job (via exit(3)).\n",
                                   __FILE__, __LINE__, ret);
                fflush(stderr);
                exit(1);
            }
        } else {
            if (ret == -FI_EAGAIN || ret == -EINTR) {
                break;
            } else {
                opal_output(0, "%s:%d: Error returned from fi_cq_read: %s(%zd).\n"
                               "*** The Open MPI OFI MTL is aborting the MPI job (via exit(3)).\n",
                               __FILE__, __LINE__, fi_strerror(-ret), ret);
                fflush(stderr);
                exit(1);
            }
        }
    }
    return count;
}

/**
 * Issue an OFI operation, retrying for as long as it reports -FI_EAGAIN.
 *
 * FUNC is the complete call expression; it is re-evaluated on every
 * iteration.  RETURN is an lvalue that receives the result of each attempt.
 * When an attempt returns -FI_EAGAIN, ompi_mtl_ofi_progress() is run to drain
 * pending Completion Queue events (which may be what prevents additional
 * operations from being enqueued) and the call is attempted again.  The
 * return value of the progress call is not examined.
 *
 * The loop ends when FUNC returns 0 or any value other than -FI_EAGAIN;
 * RETURN then holds that value and the caller must check it.
 *
 * NOTE: the expansion already ends with a semicolon, so
 * "MTL_OFI_RETRY_UNTIL_DONE(...);" yields an extra empty statement.  Do not
 * use the macro as the unbraced body of an if that is followed by an else.
 */
#define MTL_OFI_RETRY_UNTIL_DONE(FUNC, RETURN)         \
    do {                                               \
        do {                                           \
            RETURN = FUNC;                             \
            if (OPAL_LIKELY(0 == RETURN)) {break;}     \
            if (OPAL_LIKELY(RETURN == -FI_EAGAIN)) {   \
                ompi_mtl_ofi_progress();               \
            }                                          \
        } while (OPAL_LIKELY(-FI_EAGAIN == RETURN));   \
    } while (0);

/* MTL interface functions */

/* Declared here; the definition is not part of this header. */
int ompi_mtl_ofi_finalize(struct mca_mtl_base_module_t *mtl);

/**
 * Translate an OFI return code into an OMPI status code.
 *
 * @param error_num  value returned by an OFI call.
 * @return OMPI_SUCCESS when error_num is 0, OMPI_ERROR for anything else.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_get_error(int error_num)
{
    /* No finer-grained mapping exists: every non-zero code is OMPI_ERROR. */
    return (0 == error_num) ? OMPI_SUCCESS : OMPI_ERROR;
}

/**
 * Event callback installed by the blocking send path.
 *
 * Retires one outstanding completion on the request; the blocking send
 * polls completion_count until it reaches zero.
 *
 * @param wc       completion entry (not used).
 * @param ofi_req  request whose completion counter is decremented.
 * @return OMPI_SUCCESS always.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_send_callback(struct fi_cq_tagged_entry *wc,
                           ompi_mtl_ofi_request_t *ofi_req)
{
    assert(ofi_req->completion_count > 0);
    ofi_req->completion_count -= 1;

    return OMPI_SUCCESS;
}

/**
 * Error callback shared by the send and isend paths.
 *
 * Records an MPI error code on the request (MPI_ERR_TRUNCATE for FI_ETRUNC,
 * MPI_ERR_INTERN otherwise) and then runs the request's regular event
 * callback with a NULL completion entry so the request is still retired.
 *
 * @param error    CQ error entry describing the failure.
 * @param ofi_req  request the error belongs to.
 * @return whatever the request's event callback returns.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_send_error_callback(struct fi_cq_err_entry *error,
                                 ompi_mtl_ofi_request_t *ofi_req)
{
    if (FI_ETRUNC == error->err) {
        ofi_req->status.MPI_ERROR = MPI_ERR_TRUNCATE;
    } else {
        ofi_req->status.MPI_ERROR = MPI_ERR_INTERN;
    }

    return ofi_req->event_callback(NULL, ofi_req);
}

/**
 * Event callback of the ack request posted for a synchronous send.
 *
 * Releases the ack request itself and then notifies the parent send request
 * by invoking the parent's event callback with a NULL completion entry.
 * The value returned by the parent's callback is not propagated.
 *
 * @param wc       completion entry (not used).
 * @param ofi_req  the ack request; freed by this function.
 * @return OMPI_SUCCESS always.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_send_ack_callback(struct fi_cq_tagged_entry *wc,
                               ompi_mtl_ofi_request_t *ofi_req)
{
    /* Grab the parent before the ack request goes away. */
    ompi_mtl_ofi_request_t *send_req = ofi_req->parent;

    free(ofi_req);
    send_req->event_callback(NULL, send_req);

    return OMPI_SUCCESS;
}

/**
 * Error callback of the ack request posted for a synchronous send.
 *
 * Releases the ack request, marks the parent send request with
 * MPI_ERR_INTERN and forwards the error entry to the parent's own error
 * callback.
 *
 * @param error    CQ error entry describing the failure.
 * @param ofi_req  the ack request; freed by this function.
 * @return whatever the parent's error callback returns.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_send_ack_error_callback(struct fi_cq_err_entry *error,
                                     ompi_mtl_ofi_request_t *ofi_req)
{
    /* Grab the parent before the ack request goes away. */
    ompi_mtl_ofi_request_t *send_req = ofi_req->parent;

    free(ofi_req);
    send_req->status.MPI_ERROR = MPI_ERR_INTERN;

    return send_req->error_callback(error, send_req);
}

/**
 * Event callback installed by the non-blocking send path.
 *
 * Retires one outstanding completion.  Once none remain, the packed bounce
 * buffer (if any) is released, the request's MPI error code is copied to
 * the upper-layer request status, and the upper layer's completion callback
 * is invoked.
 *
 * @param wc       completion entry (not used).
 * @param ofi_req  request being progressed.
 * @return OMPI_SUCCESS always.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_isend_callback(struct fi_cq_tagged_entry *wc,
                            ompi_mtl_ofi_request_t *ofi_req)
{
    assert(ofi_req->completion_count > 0);
    ofi_req->completion_count -= 1;

    /* Still waiting on another completion: nothing else to do yet. */
    if (0 != ofi_req->completion_count) {
        return OMPI_SUCCESS;
    }

    /* Request completed */
    if (OPAL_UNLIKELY(NULL != ofi_req->buffer)) {
        free(ofi_req->buffer);
        ofi_req->buffer = NULL;
    }

    ofi_req->super.ompi_req->req_status.MPI_ERROR = ofi_req->status.MPI_ERROR;
    ofi_req->super.completion_callback(&ofi_req->super);

    return OMPI_SUCCESS;
}

/**
 * Post the zero-byte receive that collects the acknowledgment of a
 * synchronous send, and tag the outgoing message as a sync send.
 *
 * An ack request is allocated, linked to ofi_req as its parent, and posted
 * with fi_trecv() on *match_bits combined with the sync_send_ack bit.  The
 * parent's completion_count is raised by one to account for the ack.  On
 * success the SYNC_SEND bit is set in *match_bits so the caller's send
 * carries it.
 *
 * NOTE(review): ack_req is received by value, so the pointer allocated here
 * is never visible to the caller; the caller's variable keeps whatever value
 * it had.  Fixing that requires a signature change.
 *
 * @param ack_req     ignored on entry; used only as a local.
 * @param comm        communicator (not used here).
 * @param src_addr    address the ack is expected from.
 * @param ofi_req     parent send request.
 * @param endpoint    peer endpoint (not used here).
 * @param match_bits  in: send tag bits; out: same bits with SYNC_SEND set.
 * @param tag         user tag (not used here).
 * @return OMPI_SUCCESS, or OMPI_ERROR if the ack request could not be
 *         allocated or fi_trecv() failed.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_ssend_recv(ompi_mtl_ofi_request_t *ack_req,
                  struct ompi_communicator_t *comm,
                  fi_addr_t *src_addr,
                  ompi_mtl_ofi_request_t *ofi_req,
                  mca_mtl_ofi_endpoint_t *endpoint,
                  uint64_t *match_bits,
                  int tag)
{
    ssize_t ret = OMPI_SUCCESS;

    ack_req = malloc(sizeof(*ack_req));
    /* An assert() vanishes under NDEBUG; fail the send explicitly instead. */
    if (OPAL_UNLIKELY(NULL == ack_req)) {
        opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                            "%s:%d: failed to allocate sync send ack request",
                            __FILE__, __LINE__);
        return OMPI_ERROR;
    }

    ack_req->parent = ofi_req;
    ack_req->event_callback = ompi_mtl_ofi_send_ack_callback;
    ack_req->error_callback = ompi_mtl_ofi_send_ack_error_callback;

    /* The parent now also waits for the ack. */
    ofi_req->completion_count += 1;

    MTL_OFI_RETRY_UNTIL_DONE(fi_trecv(ompi_mtl_ofi.ep,
                                      NULL,
                                      0,
                                      NULL,
                                      *src_addr,
                                      *match_bits | ompi_mtl_ofi.sync_send_ack,
                                      0, /* Exact match, no ignore bits */
                                      (void *) &ack_req->ctx), ret);
    if (OPAL_UNLIKELY(0 > ret)) {
        opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                            "%s:%d: fi_trecv failed: %s(%zd)",
                            __FILE__, __LINE__, fi_strerror(-ret), ret);
        free(ack_req);
        /* No ack receive was posted, so undo the extra completion. */
        ofi_req->completion_count -= 1;
        return ompi_mtl_ofi_get_error(ret);
    }

    /* The SYNC_SEND tag bit is set for the send operation only.*/
    MTL_OFI_SET_SYNC_SEND(*match_bits);
    return OMPI_SUCCESS;
}

/**
 * Blocking send.
 *
 * Packs the user data, builds the match bits, optionally posts the
 * synchronous-send ack receive, and then either injects the message (when it
 * fits in max_inject_size, in which case no send completion is awaited) or
 * posts a tagged send and polls ompi_mtl_ofi_progress() until every
 * outstanding completion of the on-stack request has been retired.
 *
 * The packed bounce buffer, if one was allocated by the pack step, is
 * released on every exit path after packing.
 *
 * @param mtl        MTL module, used to resolve the peer endpoint.
 * @param comm       communicator the message is sent on.
 * @param dest       destination rank in comm.
 * @param tag        user tag.
 * @param convertor  describes the data to send.
 * @param mode       send mode; MCA_PML_BASE_SEND_SYNCHRONOUS enables the ack
 *                   protocol.
 * @return OMPI_SUCCESS, the error from packing, OMPI_ERROR when the message
 *         exceeds max_msg_size or an OFI call fails, or the MPI error code
 *         recorded on the request by a callback.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_send(struct mca_mtl_base_module_t *mtl,
                  struct ompi_communicator_t *comm,
                  int dest,
                  int tag,
                  struct opal_convertor_t *convertor,
                  mca_pml_base_send_mode_t mode)
{
    ssize_t ret = OMPI_SUCCESS;
    ompi_mtl_ofi_request_t ofi_req;
    int ompi_ret;
    void *start;
    bool free_after;
    size_t length;
    uint64_t match_bits;
    ompi_proc_t *ompi_proc = NULL;
    mca_mtl_ofi_endpoint_t *endpoint = NULL;
    /* NOTE(review): ompi_mtl_ofi_ssend_recv() takes this pointer by value,
     * so it stays NULL here and the cancel branch below never runs. */
    ompi_mtl_ofi_request_t *ack_req = NULL; /* For synchronous send */
    fi_addr_t src_addr = 0;

    /**
     * Create a send request, start it and wait until it completes.
     */
    ofi_req.event_callback = ompi_mtl_ofi_send_callback;
    ofi_req.error_callback = ompi_mtl_ofi_send_error_callback;

    ompi_proc = ompi_comm_peer_lookup(comm, dest);
    endpoint = ompi_mtl_ofi_get_endpoint(mtl, ompi_proc);

    ompi_ret = ompi_mtl_datatype_pack(convertor, &start, &length, &free_after);
    if (OPAL_UNLIKELY(OMPI_SUCCESS != ompi_ret)) {
        return ompi_ret;
    }

    ofi_req.buffer = (free_after) ? start : NULL;
    ofi_req.length = length;
    ofi_req.status.MPI_ERROR = OMPI_SUCCESS;
    ofi_req.completion_count = 0;

    if (OPAL_UNLIKELY(length > endpoint->mtl_ofi_module->max_msg_size)) {
        opal_show_help("help-mtl-ofi.txt",
            "message too big", false,
            length, endpoint->mtl_ofi_module->max_msg_size);
        /* Leave through the cleanup label so the packed buffer is freed. */
        ofi_req.status.MPI_ERROR = OMPI_ERROR;
        goto free_request_buffer;
    }

    if (ompi_mtl_ofi.fi_cq_data) {
        match_bits = mtl_ofi_create_send_tag_CQD(comm->c_contextid, tag);
        src_addr = endpoint->peer_fiaddr;
    } else {
        match_bits = mtl_ofi_create_send_tag(comm->c_contextid,
                                             comm->c_my_rank, tag);
        /* src_addr is ignored when FI_DIRECTED_RECV is not supported */
    }

    if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_SYNCHRONOUS == mode)) {
        ofi_req.status.MPI_ERROR = ompi_mtl_ofi_ssend_recv(ack_req, comm, &src_addr,
                                                           &ofi_req, endpoint,
                                                           &match_bits, tag);
        if (OPAL_UNLIKELY(ofi_req.status.MPI_ERROR != OMPI_SUCCESS))
            goto free_request_buffer;
    }

    if (ompi_mtl_ofi.max_inject_size >= length) {
        /* Small message: inject it; no send completion is counted. */
        if (ompi_mtl_ofi.fi_cq_data) {
            MTL_OFI_RETRY_UNTIL_DONE(fi_tinjectdata(ompi_mtl_ofi.ep,
                                            start,
                                            length,
                                            comm->c_my_rank,
                                            endpoint->peer_fiaddr,
                                            match_bits), ret);
        } else {
            MTL_OFI_RETRY_UNTIL_DONE(fi_tinject(ompi_mtl_ofi.ep,
                                            start,
                                            length,
                                            endpoint->peer_fiaddr,
                                            match_bits), ret);
        }
        if (OPAL_UNLIKELY(0 > ret)) {
            const char *fi_api = ompi_mtl_ofi.fi_cq_data ? "fi_tinjectdata" : "fi_tinject";
            opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                                "%s:%d: %s failed: %s(%zd)",
                                __FILE__, __LINE__,fi_api, fi_strerror(-ret), ret);

            if (ack_req) {
                fi_cancel((fid_t)ompi_mtl_ofi.ep, &ack_req->ctx);
                free(ack_req);
            }

            ofi_req.status.MPI_ERROR = ompi_mtl_ofi_get_error(ret);
            goto free_request_buffer;
        }
    } else {
        /* Regular send: one more completion to wait for. */
        ofi_req.completion_count += 1;
        if (ompi_mtl_ofi.fi_cq_data) {
            MTL_OFI_RETRY_UNTIL_DONE(fi_tsenddata(ompi_mtl_ofi.ep,
                                          start,
                                          length,
                                          NULL,
                                          comm->c_my_rank,
                                          endpoint->peer_fiaddr,
                                          match_bits,
                                          (void *) &ofi_req.ctx), ret);
        } else {
            MTL_OFI_RETRY_UNTIL_DONE(fi_tsend(ompi_mtl_ofi.ep,
                                          start,
                                          length,
                                          NULL,
                                          endpoint->peer_fiaddr,
                                          match_bits,
                                          (void *) &ofi_req.ctx), ret);
        }
        if (OPAL_UNLIKELY(0 > ret)) {
            /* fi_api names a string literal; it must never be passed to free(). */
            const char *fi_api = ompi_mtl_ofi.fi_cq_data ? "fi_tsenddata" : "fi_tsend";
            opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                                "%s:%d: %s failed: %s(%zd)",
                                __FILE__, __LINE__,fi_api, fi_strerror(-ret), ret);

            ofi_req.status.MPI_ERROR = ompi_mtl_ofi_get_error(ret);
            goto free_request_buffer;
        }
    }

    /**
     * Wait until the request is completed.
     * ompi_mtl_ofi_send_callback() updates this variable.
     */
    while (0 < ofi_req.completion_count) {
        ompi_mtl_ofi_progress();
    }

free_request_buffer:
    if (OPAL_UNLIKELY(NULL != ofi_req.buffer)) {
        free(ofi_req.buffer);
    }

    return ofi_req.status.MPI_ERROR;
}

/**
 * Non-blocking send.
 *
 * Packs the user data, builds the match bits, optionally posts the
 * synchronous-send ack receive, and posts a tagged send whose completion is
 * delivered to ompi_mtl_ofi_isend_callback().  The request starts with one
 * outstanding completion (the send itself); the synchronous path adds one
 * for the ack.
 *
 * On any failure after packing, the packed bounce buffer (if one was
 * allocated) is released here and the request's buffer pointer is cleared.
 * On success the buffer is released by the completion callback.
 *
 * @param mtl          MTL module, used to resolve the peer endpoint.
 * @param comm         communicator the message is sent on.
 * @param dest         destination rank in comm.
 * @param tag          user tag.
 * @param convertor    describes the data to send.
 * @param mode         send mode; MCA_PML_BASE_SEND_SYNCHRONOUS enables the
 *                     ack protocol.
 * @param blocking     not used by this implementation.
 * @param mtl_request  caller-owned request, used as an ompi_mtl_ofi_request_t.
 * @return OMPI_SUCCESS, the error from packing, or OMPI_ERROR when the
 *         message exceeds max_msg_size or an OFI call fails.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_isend(struct mca_mtl_base_module_t *mtl,
                   struct ompi_communicator_t *comm,
                   int dest,
                   int tag,
                   struct opal_convertor_t *convertor,
                   mca_pml_base_send_mode_t mode,
                   bool blocking,
                   mca_mtl_request_t *mtl_request)
{
    ssize_t ret = OMPI_SUCCESS;
    ompi_mtl_ofi_request_t *ofi_req = (ompi_mtl_ofi_request_t *) mtl_request;
    int ompi_ret;
    void *start;
    size_t length;
    bool free_after;
    uint64_t match_bits;
    ompi_proc_t *ompi_proc = NULL;
    mca_mtl_ofi_endpoint_t *endpoint = NULL;
    /* NOTE(review): ompi_mtl_ofi_ssend_recv() takes this pointer by value,
     * so it stays NULL in this function. */
    ompi_mtl_ofi_request_t *ack_req = NULL; /* For synchronous send */
    fi_addr_t src_addr = 0;

    ofi_req->event_callback = ompi_mtl_ofi_isend_callback;
    ofi_req->error_callback = ompi_mtl_ofi_send_error_callback;

    ompi_proc = ompi_comm_peer_lookup(comm, dest);
    endpoint = ompi_mtl_ofi_get_endpoint(mtl, ompi_proc);

    ompi_ret = ompi_mtl_datatype_pack(convertor, &start, &length, &free_after);
    if (OPAL_UNLIKELY(OMPI_SUCCESS != ompi_ret)) return ompi_ret;

    ofi_req->buffer = (free_after) ? start : NULL;
    ofi_req->length = length;
    ofi_req->status.MPI_ERROR = OMPI_SUCCESS;
    ofi_req->completion_count = 1;

    if (OPAL_UNLIKELY(length > endpoint->mtl_ofi_module->max_msg_size)) {
        opal_show_help("help-mtl-ofi.txt",
            "message too big", false,
            length, endpoint->mtl_ofi_module->max_msg_size);
        /* Leave through the cleanup label so the packed buffer is freed. */
        ofi_req->status.MPI_ERROR = OMPI_ERROR;
        goto free_request_buffer;
    }

    if (ompi_mtl_ofi.fi_cq_data) {
        match_bits = mtl_ofi_create_send_tag_CQD(comm->c_contextid, tag);
        src_addr = endpoint->peer_fiaddr;
    } else {
        match_bits = mtl_ofi_create_send_tag(comm->c_contextid,
                          comm->c_my_rank, tag);
        /* src_addr is ignored when FI_DIRECTED_RECV  is not supported */
    }

    if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_SYNCHRONOUS == mode)) {
        ofi_req->status.MPI_ERROR = ompi_mtl_ofi_ssend_recv(ack_req, comm, &src_addr,
                                                           ofi_req, endpoint,
                                                           &match_bits, tag);
        if (OPAL_UNLIKELY(ofi_req->status.MPI_ERROR != OMPI_SUCCESS))
            goto free_request_buffer;
    }

    if (ompi_mtl_ofi.fi_cq_data) {
        MTL_OFI_RETRY_UNTIL_DONE(fi_tsenddata(ompi_mtl_ofi.ep,
                                      start,
                                      length,
                                      NULL,
                                      comm->c_my_rank,
                                      endpoint->peer_fiaddr,
                                      match_bits,
                                      (void *) &ofi_req->ctx), ret);
    } else {
        MTL_OFI_RETRY_UNTIL_DONE(fi_tsend(ompi_mtl_ofi.ep,
                                      start,
                                      length,
                                      NULL,
                                      endpoint->peer_fiaddr,
                                      match_bits,
                                      (void *) &ofi_req->ctx), ret);
    }
    if (OPAL_UNLIKELY(0 > ret)) {
        /* Constant names: no heap allocation needed just to log them. */
        const char *fi_api = ompi_mtl_ofi.fi_cq_data ? "fi_tsenddata" : "fi_tsend";
        opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                            "%s:%d: %s failed: %s(%zd)",
                            __FILE__, __LINE__,fi_api, fi_strerror(-ret), ret);
        ofi_req->status.MPI_ERROR = ompi_mtl_ofi_get_error(ret);
    }

free_request_buffer:
    if (OPAL_UNLIKELY(OMPI_SUCCESS != ofi_req->status.MPI_ERROR
            && NULL != ofi_req->buffer)) {
        free(ofi_req->buffer);
        /* The request outlives this call; do not leave a dangling pointer. */
        ofi_req->buffer = NULL;
    }

    return ofi_req->status.MPI_ERROR;
}

/**
 * Called when a completion for a posted recv is received.
 *
 * Marks the request as started, fills the upper-layer status (source, tag,
 * byte count) from the completion entry, flags truncation when the received
 * length exceeds the posted length, unpacks the bounce buffer when one is in
 * use, answers a synchronous send with a zero-byte ack, and finally invokes
 * the upper layer's completion callback.
 *
 * @param wc       completion entry for the receive.
 * @param ofi_req  the receive request.
 * @return OMPI_SUCCESS always; failures are reported through
 *         status->MPI_ERROR.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_recv_callback(struct fi_cq_tagged_entry *wc,
                           ompi_mtl_ofi_request_t *ofi_req)
{
    int ompi_ret;
    ssize_t ret;
    ompi_proc_t *ompi_proc = NULL;
    mca_mtl_ofi_endpoint_t *endpoint = NULL;
    int src = mtl_ofi_get_source(wc);
    ompi_status_public_t *status = NULL;
    struct fi_msg_tagged tagged_msg;

    assert(ofi_req->super.ompi_req);
    status = &ofi_req->super.ompi_req->req_status;

    /**
     * Any event associated with a request starts it.
     * This prevents a started request from being cancelled.
     */
    ofi_req->req_started = true;

    status->MPI_SOURCE = src;
    status->MPI_TAG = MTL_OFI_GET_TAG(wc->tag);
    status->_ucount = wc->len;

    if (OPAL_UNLIKELY(wc->len > ofi_req->length)) {
        /* Lengths are sizes: print them with %zu rather than %ld. */
        opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                            "truncate expected: %zu %zu",
                            (size_t) wc->len, (size_t) ofi_req->length);
        status->MPI_ERROR = MPI_ERR_TRUNCATE;
    }

    /**
     * Unpack data into recv buffer if necessary.
     */
    if (OPAL_UNLIKELY(ofi_req->buffer)) {
        ompi_ret = ompi_mtl_datatype_unpack(ofi_req->convertor,
                                            ofi_req->buffer,
                                            wc->len);
        if (OPAL_UNLIKELY(OMPI_SUCCESS != ompi_ret)) {
            opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                                "%s:%d: ompi_mtl_datatype_unpack failed: %d",
                                __FILE__, __LINE__, ompi_ret);
            status->MPI_ERROR = ompi_ret;
        }
    }

    /**
    * We can only accept MTL_OFI_SYNC_SEND in the standard recv callback.
    * MTL_OFI_SYNC_SEND_ACK should only be received in the send_ack
    * callback.
    */
    assert(!MTL_OFI_IS_SYNC_SEND_ACK(wc->tag));

    /**
     * If this recv is part of an MPI_Ssend operation, then we send an
     * acknowledgment back to the sender.
     * The ack message is sent without generating a completion event in
     * the completion queue by not setting FI_COMPLETION in the flags to
     * fi_tsendmsg(FI_SELECTIVE_COMPLETION).
     * This is done since the 0 byte message requires no
     * notification on the send side for a successful completion.
     * If a failure occurs the provider will notify the error
     * in the cq_readerr during OFI progress. Once the message has been
     * successfully processed the request is marked as completed.
     */
    if (OPAL_UNLIKELY(MTL_OFI_IS_SYNC_SEND(wc->tag))) {
        /**
         * If the recv request was posted for any source,
         * we need to extract the source's actual address.
         */
        if (ompi_mtl_ofi.any_addr == ofi_req->remote_addr) {
            ompi_proc = ompi_comm_peer_lookup(ofi_req->comm, src);
            endpoint = ompi_mtl_ofi_get_endpoint(ofi_req->mtl, ompi_proc);
            ofi_req->remote_addr = endpoint->peer_fiaddr;
        }

        tagged_msg.msg_iov = NULL;
        tagged_msg.desc = NULL;
        tagged_msg.iov_count = 0;
        tagged_msg.addr = ofi_req->remote_addr;
        /**
        * We must continue to use the user's original tag but remove the
        * sync_send protocol tag bit and instead apply the sync_send_ack
        * tag bit to complete the initiator's sync send receive.
        */
        tagged_msg.tag = (wc->tag | ompi_mtl_ofi.sync_send_ack) & ~ompi_mtl_ofi.sync_send;
        tagged_msg.context = NULL;
        tagged_msg.data = 0;

        MTL_OFI_RETRY_UNTIL_DONE(fi_tsendmsg(ompi_mtl_ofi.ep,
                                 &tagged_msg, 0), ret);
        if (OPAL_UNLIKELY(0 > ret)) {
            opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                                "%s:%d: fi_tsendmsg failed: %s(%zd)",
                                __FILE__, __LINE__, fi_strerror(-ret), ret);
            status->MPI_ERROR = OMPI_ERROR;
        }
    }

    ofi_req->super.completion_callback(&ofi_req->super);

    return OMPI_SUCCESS;
}

/**
 * Called when an error occurred on a recv request.
 *
 * Fills tag and source in the upper-layer status, translates the OFI error
 * (FI_ETRUNC -> MPI_ERR_TRUNCATE, FI_ECANCELED -> request marked cancelled,
 * anything else -> MPI_ERR_INTERN) and completes the request.
 *
 * @param error    CQ error entry describing the failure.
 * @param ofi_req  the receive request.
 * @return OMPI_SUCCESS always.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_recv_error_callback(struct fi_cq_err_entry *error,
                                 ompi_mtl_ofi_request_t *ofi_req)
{
    ompi_status_public_t *req_status;

    assert(ofi_req->super.ompi_req);
    req_status = &ofi_req->super.ompi_req->req_status;

    req_status->MPI_TAG = MTL_OFI_GET_TAG(ofi_req->match_bits);
    /* The error entry is reinterpreted as a tagged entry to read the source. */
    req_status->MPI_SOURCE = mtl_ofi_get_source((struct fi_cq_tagged_entry *) error);

    if (FI_ETRUNC == error->err) {
        req_status->MPI_ERROR = MPI_ERR_TRUNCATE;
    } else if (FI_ECANCELED == error->err) {
        /* A cancelled receive is flagged, not reported as an error. */
        req_status->_cancelled = true;
    } else {
        req_status->MPI_ERROR = MPI_ERR_INTERN;
    }

    ofi_req->super.completion_callback(&ofi_req->super);

    return OMPI_SUCCESS;
}

/**
 * Post a non-blocking tagged receive.
 *
 * Builds the match/ignore bits for (comm, src, tag), obtains the receive
 * buffer (possibly a bounce buffer that must be unpacked and freed later),
 * initializes the request and posts it with fi_trecv().  Completion is
 * delivered to ompi_mtl_ofi_recv_callback() / ompi_mtl_ofi_recv_error_callback().
 *
 * @param mtl          MTL module; stored on the request because the recv
 *                     callback reads it when resolving the sender's endpoint.
 * @param comm         communicator to receive on.
 * @param src          source rank or MPI_ANY_SOURCE.
 * @param tag          user tag.
 * @param convertor    describes the destination buffer.
 * @param mtl_request  caller-owned request, used as an ompi_mtl_ofi_request_t.
 * @return OMPI_SUCCESS, the error from obtaining the receive buffer, or
 *         OMPI_ERROR when fi_trecv() fails.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_irecv(struct mca_mtl_base_module_t *mtl,
                   struct ompi_communicator_t *comm,
                   int src,
                   int tag,
                   struct opal_convertor_t *convertor,
                   mca_mtl_request_t *mtl_request)
{
    int ompi_ret = OMPI_SUCCESS;
    ssize_t ret;
    uint64_t match_bits, mask_bits;
    fi_addr_t remote_addr = ompi_mtl_ofi.any_addr;
    ompi_proc_t *ompi_proc = NULL;
    mca_mtl_ofi_endpoint_t *endpoint = NULL;
    ompi_mtl_ofi_request_t *ofi_req = (ompi_mtl_ofi_request_t*) mtl_request;
    void *start;
    size_t length;
    bool free_after;


    if (ompi_mtl_ofi.fi_cq_data) {
        /* With CQ data, a specific source is selected by address. */
        if (MPI_ANY_SOURCE != src) {
            ompi_proc = ompi_comm_peer_lookup(comm, src);
            endpoint = ompi_mtl_ofi_get_endpoint(mtl, ompi_proc);
            remote_addr = endpoint->peer_fiaddr;
        }

        mtl_ofi_create_recv_tag_CQD(&match_bits, &mask_bits, comm->c_contextid,
                                    tag);
    } else {
        mtl_ofi_create_recv_tag(&match_bits, &mask_bits, comm->c_contextid, src,
                                tag);
        /* src_addr is ignored when FI_DIRECTED_RECV is not used */
    }

    ompi_ret = ompi_mtl_datatype_recv_buf(convertor,
                                          &start,
                                          &length,
                                          &free_after);
    if (OPAL_UNLIKELY(OMPI_SUCCESS != ompi_ret)) {
        return ompi_ret;
    }

    ofi_req->type = OMPI_MTL_OFI_RECV;
    ofi_req->event_callback = ompi_mtl_ofi_recv_callback;
    ofi_req->error_callback = ompi_mtl_ofi_recv_error_callback;
    ofi_req->comm = comm;
    /* The recv callback dereferences ofi_req->mtl for any-source sync sends. */
    ofi_req->mtl = mtl;
    ofi_req->buffer = (free_after) ? start : NULL;
    ofi_req->length = length;
    ofi_req->convertor = convertor;
    ofi_req->req_started = false;
    ofi_req->status.MPI_ERROR = OMPI_SUCCESS;
    ofi_req->remote_addr = remote_addr;
    ofi_req->match_bits = match_bits;

    MTL_OFI_RETRY_UNTIL_DONE(fi_trecv(ompi_mtl_ofi.ep,
                                      start,
                                      length,
                                      NULL,
                                      remote_addr,
                                      match_bits,
                                      mask_bits,
                                      (void *)&ofi_req->ctx), ret);
    if (OPAL_UNLIKELY(0 > ret)) {
        if (NULL != ofi_req->buffer) {
            free(ofi_req->buffer);
            /* The request outlives this call; do not leave a dangling pointer. */
            ofi_req->buffer = NULL;
        }
        opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                            "%s:%d: fi_trecv failed: %s(%zd)",
                            __FILE__, __LINE__, fi_strerror(-ret), ret);
        return ompi_mtl_ofi_get_error(ret);
    }

    return OMPI_SUCCESS;
}

/**
 * Called when a mrecv request completes.
 *
 * Copies source, tag and byte count from the completion entry into the
 * status of the upper-layer request attached to ofi_req, releases ofi_req,
 * and then completes the upper-layer request.
 *
 * @param wc       completion entry for the matched receive.
 * @param ofi_req  the mrecv request; freed by this function.
 * @return OMPI_SUCCESS always.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_mrecv_callback(struct fi_cq_tagged_entry *wc,
                            ompi_mtl_ofi_request_t *ofi_req)
{
    /* Keep the upper-layer request: ofi_req is freed before it is completed. */
    struct mca_mtl_request_t *upper_req = ofi_req->mrecv_req;
    ompi_status_public_t *req_status = &upper_req->ompi_req->req_status;

    req_status->MPI_SOURCE = mtl_ofi_get_source(wc);
    req_status->MPI_TAG = MTL_OFI_GET_TAG(wc->tag);
    req_status->MPI_ERROR = MPI_SUCCESS;
    req_status->_ucount = wc->len;

    free(ofi_req);
    upper_req->completion_callback(upper_req);

    return OMPI_SUCCESS;
}

/**
 * Called when an error occurred on a mrecv request.
 *
 * Fills tag and source in the status of the upper-layer request, translates
 * the OFI error (FI_ETRUNC -> MPI_ERR_TRUNCATE, FI_ECANCELED -> request
 * marked cancelled, anything else -> MPI_ERR_INTERN), releases ofi_req and
 * completes the upper-layer request.
 *
 * @param error    CQ error entry describing the failure.
 * @param ofi_req  the mrecv request; freed by this function.
 * @return OMPI_SUCCESS always.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_mrecv_error_callback(struct fi_cq_err_entry *error,
                                  ompi_mtl_ofi_request_t *ofi_req)
{
    /* Keep the upper-layer request: ofi_req is freed before it is completed. */
    struct mca_mtl_request_t *upper_req = ofi_req->mrecv_req;
    ompi_status_public_t *req_status = &upper_req->ompi_req->req_status;

    req_status->MPI_TAG = MTL_OFI_GET_TAG(ofi_req->match_bits);
    /* The error entry is reinterpreted as a tagged entry to read the source. */
    req_status->MPI_SOURCE = mtl_ofi_get_source((struct fi_cq_tagged_entry  *) error);

    if (FI_ETRUNC == error->err) {
        req_status->MPI_ERROR = MPI_ERR_TRUNCATE;
    } else if (FI_ECANCELED == error->err) {
        /* A cancelled receive is flagged, not reported as an error. */
        req_status->_cancelled = true;
    } else {
        req_status->MPI_ERROR = MPI_ERR_INTERN;
    }

    free(ofi_req);
    upper_req->completion_callback(upper_req);

    return OMPI_SUCCESS;
}

/**
 * Post the receive for a previously matched message (matched receive).
 *
 * The OFI request is taken from (*message)->req_ptr; its match_bits and
 * mask_bits are reused to claim the message with fi_trecvmsg(FI_CLAIM).
 * Completion is delivered to ompi_mtl_ofi_mrecv_callback() /
 * ompi_mtl_ofi_mrecv_error_callback(), which complete mtl_request.
 *
 * If fi_trecvmsg() fails, the bounce buffer (if one was allocated for the
 * receive) is released here, as ompi_mtl_ofi_irecv() does on its failure
 * path.
 *
 * @param mtl          MTL module (not used here).
 * @param convertor    describes the destination buffer.
 * @param message      matched message carrying the OFI request.
 * @param mtl_request  upper-layer request to complete.
 * @return OMPI_SUCCESS, the error from obtaining the receive buffer, or
 *         OMPI_ERROR when fi_trecvmsg() fails.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_imrecv(struct mca_mtl_base_module_t *mtl,
                    struct opal_convertor_t *convertor,
                    struct ompi_message_t **message,
                    struct mca_mtl_request_t *mtl_request)
{
    ompi_mtl_ofi_request_t *ofi_req =
        (ompi_mtl_ofi_request_t *)(*message)->req_ptr;
    void *start;
    size_t length;
    bool free_after;
    struct iovec iov;
    struct fi_msg_tagged msg;
    int ompi_ret;
    ssize_t ret;
    uint64_t msgflags = FI_CLAIM | FI_COMPLETION;

    ompi_ret = ompi_mtl_datatype_recv_buf(convertor,
                                          &start,
                                          &length,
                                          &free_after);
    if (OPAL_UNLIKELY(OMPI_SUCCESS != ompi_ret)) {
        return ompi_ret;
    }

    ofi_req->type = OMPI_MTL_OFI_RECV;
    ofi_req->event_callback = ompi_mtl_ofi_mrecv_callback;
    ofi_req->error_callback = ompi_mtl_ofi_mrecv_error_callback;
    ofi_req->buffer = (free_after) ? start : NULL;
    ofi_req->length = length;
    ofi_req->convertor = convertor;
    ofi_req->status.MPI_ERROR = OMPI_SUCCESS;
    ofi_req->mrecv_req = mtl_request;

    /**
     * fi_trecvmsg with FI_CLAIM
     */
    iov.iov_base = start;
    iov.iov_len = length;
    msg.msg_iov = &iov;
    msg.desc = NULL;
    msg.iov_count = 1;
    msg.addr = 0;
    msg.tag = ofi_req->match_bits;
    msg.ignore = ofi_req->mask_bits;
    msg.context = (void *)&ofi_req->ctx;
    msg.data = 0;

    MTL_OFI_RETRY_UNTIL_DONE(fi_trecvmsg(ompi_mtl_ofi.ep, &msg, msgflags), ret);
    if (OPAL_UNLIKELY(0 > ret)) {
        /* Nothing was posted, so no callback will ever release the buffer. */
        if (NULL != ofi_req->buffer) {
            free(ofi_req->buffer);
            ofi_req->buffer = NULL;
        }
        opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                            "%s:%d: fi_trecvmsg failed: %s(%zd)",
                            __FILE__, __LINE__, fi_strerror(-ret), ret);
        return ompi_mtl_ofi_get_error(ret);
    }

    return OMPI_SUCCESS;
}

/**
 * Called when a probe request completes.
 *
 * Records that a match was found, stores the matched tag bits, fills the
 * request's own status (source, tag, MPI_SUCCESS, byte count) from the
 * completion entry, and retires one outstanding completion.
 *
 * @param wc       completion entry describing the matched message.
 * @param ofi_req  the probe request.
 * @return OMPI_SUCCESS always.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_probe_callback(struct fi_cq_tagged_entry *wc,
                            ompi_mtl_ofi_request_t *ofi_req)
{
    /* Match bookkeeping. */
    ofi_req->match_state = 1;
    ofi_req->match_bits = wc->tag;

    /* Status reported back to the prober. */
    ofi_req->status.MPI_SOURCE = mtl_ofi_get_source(wc);
    ofi_req->status.MPI_TAG = MTL_OFI_GET_TAG(wc->tag);
    ofi_req->status.MPI_ERROR = MPI_SUCCESS;
    ofi_req->status._ucount = wc->len;

    ofi_req->completion_count -= 1;

    return OMPI_SUCCESS;
}

/**
 * Called when a probe request encounters an error.
 *
 * Marks the request with MPI_ERR_INTERN and retires one outstanding
 * completion; the error entry itself is not inspected.
 *
 * @param error    CQ error entry (not used).
 * @param ofi_req  the probe request.
 * @return OMPI_SUCCESS always.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_probe_error_callback(struct fi_cq_err_entry *error,
                                  ompi_mtl_ofi_request_t *ofi_req)
{
    ofi_req->status.MPI_ERROR = MPI_ERR_INTERN;
    ofi_req->completion_count -= 1;

    return OMPI_SUCCESS;
}

__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_iprobe(struct mca_mtl_base_module_t *mtl,
                    struct ompi_communicator_t *comm,
                    int src,
                    int tag,
                    int *flag,
                    struct ompi_status_public_t *status)
{
    struct ompi_mtl_ofi_request_t ofi_req;
    ompi_proc_t *ompi_proc = NULL;
    mca_mtl_ofi_endpoint_t *endpoint = NULL;
    fi_addr_t remote_proc = ompi_mtl_ofi.any_addr;
    uint64_t match_bits, mask_bits;
    ssize_t ret;
    struct fi_msg_tagged msg;
    uint64_t msgflags = FI_PEEK | FI_COMPLETION;

    if (ompi_mtl_ofi.fi_cq_data) {
     /* If the source is known, use its peer_fiaddr. */
        if (MPI_ANY_SOURCE != src) {
            ompi_proc = ompi_comm_peer_lookup( comm, src );
            endpoint = ompi_mtl_ofi_get_endpoint(mtl, ompi_proc);
            remote_proc = endpoint->peer_fiaddr;
        }

        mtl_ofi_create_recv_tag_CQD(&match_bits, &mask_bits, comm->c_contextid,
                                    tag);
    }
    else {
        mtl_ofi_create_recv_tag(&match_bits, &mask_bits, comm->c_contextid, src,
                                tag);
        /* src_addr is ignored when FI_DIRECTED_RECV is not used */
    }

    /**
     * fi_trecvmsg with FI_PEEK:
     * Initiate a search for a match in the hardware or software queue.
     * The search can complete immediately with -ENOMSG.
     * If successful, libfabric will enqueue a context entry into the completion
     * queue to make the search nonblocking.  This code will poll until the
     * entry is enqueued.
     */
    msg.msg_iov = NULL;
    msg.desc = NULL;
    msg.iov_count = 0;
    msg.addr = remote_proc;
    msg.tag = match_bits;
    msg.ignore = mask_bits;
    msg.context = (void *)&ofi_req.ctx;
    msg.data = 0;

    ofi_req.type = OMPI_MTL_OFI_PROBE;
    ofi_req.event_callback = ompi_mtl_ofi_probe_callback;
    ofi_req.error_callback = ompi_mtl_ofi_probe_error_callback;
    ofi_req.completion_count = 1;
    ofi_req.match_state = 0;

    MTL_OFI_RETRY_UNTIL_DONE(fi_trecvmsg(ompi_mtl_ofi.ep, &msg, msgflags), ret);
    if (-FI_ENOMSG == ret) {
        /**
         * The search request completed but no matching message was found.
         */
        *flag = 0;
        return OMPI_SUCCESS;
    } else if (OPAL_UNLIKELY(0 > ret)) {
        opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                            "%s:%d: fi_trecvmsg failed: %s(%zd)",
                            __FILE__, __LINE__, fi_strerror(-ret), ret);
        return ompi_mtl_ofi_get_error(ret);
    }

    while (0 < ofi_req.completion_count) {
        opal_progress();
    }

    *flag = ofi_req.match_state;
    if (1 == *flag) {
        if (MPI_STATUS_IGNORE != status) {
            *status = ofi_req.status;
        }
    }

    return OMPI_SUCCESS;
}

/**
 * Non-blocking matched probe.
 *
 * Posts an fi_trecvmsg() with FI_PEEK | FI_CLAIM to search for a message
 * matching (comm, src, tag).  If the search is accepted, spins on
 * opal_progress() until the probe request's completion counter drops to
 * zero.  On a match, a message object is allocated and the heap-allocated
 * request is attached to it through req_ptr.
 *
 * Ownership of the request: it is freed here on every path that does not
 * hand it over to a message object.
 *
 * @param mtl      MTL module, forwarded to ompi_mtl_ofi_get_endpoint()
 * @param comm     communicator whose context id is encoded in the match bits
 * @param src      source rank, or MPI_ANY_SOURCE
 * @param tag      tag to match
 * @param matched  [out] set to the request's match_state (0 when
 *                 fi_trecvmsg() reports -FI_ENOMSG)
 * @param message  [out] on a match, the newly allocated message; set to
 *                 MPI_MESSAGE_NULL when the search was posted but did not
 *                 match.  Not written when fi_trecvmsg() itself returns
 *                 -FI_ENOMSG or another error.
 * @param status   [out] filled from the probe request when a match was found
 *                 and status is not MPI_STATUS_IGNORE
 *
 * @return OMPI_SUCCESS; OMPI_ERROR if the request cannot be allocated;
 *         OMPI_ERR_OUT_OF_RESOURCE if the message object cannot be
 *         allocated; otherwise the translation of the fi_trecvmsg() error
 *         code produced by ompi_mtl_ofi_get_error().
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_improbe(struct mca_mtl_base_module_t *mtl,
                     struct ompi_communicator_t *comm,
                     int src,
                     int tag,
                     int *matched,
                     struct ompi_message_t **message,
                     struct ompi_status_public_t *status)
{
    struct ompi_mtl_ofi_request_t *ofi_req;
    ompi_proc_t *ompi_proc = NULL;
    mca_mtl_ofi_endpoint_t *endpoint = NULL;
    fi_addr_t remote_proc = ompi_mtl_ofi.any_addr;
    uint64_t match_bits, mask_bits;
    ssize_t ret;
    struct fi_msg_tagged msg;
    uint64_t msgflags = FI_PEEK | FI_CLAIM | FI_COMPLETION;

    ofi_req = malloc(sizeof *ofi_req);
    if (NULL == ofi_req) {
        return OMPI_ERROR;
    }

    /**
     * If the source is known, use its peer_fiaddr.
     */

    if (ompi_mtl_ofi.fi_cq_data) {
        if (MPI_ANY_SOURCE != src) {
            ompi_proc = ompi_comm_peer_lookup( comm, src );
            endpoint = ompi_mtl_ofi_get_endpoint(mtl, ompi_proc);
            remote_proc = endpoint->peer_fiaddr;
        }

        mtl_ofi_create_recv_tag_CQD(&match_bits, &mask_bits, comm->c_contextid,
                                    tag);
    }
    else {
        /* src_addr is ignored when FI_DIRECTED_RECV is not used */
        mtl_ofi_create_recv_tag(&match_bits, &mask_bits, comm->c_contextid, src,
                                tag);
    }

    /**
     * fi_trecvmsg with FI_PEEK and FI_CLAIM:
     * Initiate a search for a match in the hardware or software queue.
     * The search can complete immediately with -ENOMSG.
     * If successful, libfabric will enqueue a context entry into the completion
     * queue to make the search nonblocking.  This code will poll until the
     * entry is enqueued.
     */
    msg.msg_iov = NULL;
    msg.desc = NULL;
    msg.iov_count = 0;
    msg.addr = remote_proc;
    msg.tag = match_bits;
    msg.ignore = mask_bits;
    msg.context = (void *)&ofi_req->ctx;
    msg.data = 0;

    ofi_req->type = OMPI_MTL_OFI_PROBE;
    ofi_req->event_callback = ompi_mtl_ofi_probe_callback;
    ofi_req->error_callback = ompi_mtl_ofi_probe_error_callback;
    ofi_req->completion_count = 1;
    ofi_req->match_state = 0;
    ofi_req->mask_bits = mask_bits;

    MTL_OFI_RETRY_UNTIL_DONE(fi_trecvmsg(ompi_mtl_ofi.ep, &msg, msgflags), ret);
    if (-FI_ENOMSG == ret) {
        /**
         * The search request completed but no matching message was found.
         */
        *matched = 0;
        free(ofi_req);
        return OMPI_SUCCESS;
    } else if (OPAL_UNLIKELY(0 > ret)) {
        opal_output_verbose(1, ompi_mtl_base_framework.framework_output,
                            "%s:%d: fi_trecvmsg failed: %s(%zd)",
                            __FILE__, __LINE__, fi_strerror(-ret), ret);
        free(ofi_req);
        return ompi_mtl_ofi_get_error(ret);
    }

    /* Drive progress until one of the probe callbacks has run. */
    while (0 < ofi_req->completion_count) {
        opal_progress();
    }

    *matched = ofi_req->match_state;
    if (1 == *matched) {
        if (MPI_STATUS_IGNORE != status) {
            *status = ofi_req->status;
        }

        (*message) = ompi_message_alloc();
        if (NULL == (*message)) {
            /**
             * No message object exists to take ownership of the request,
             * so release it here instead of leaking it.
             * NOTE(review): the message claimed through FI_CLAIM is not
             * released back to the provider on this path -- confirm whether
             * an FI_CLAIM | FI_DISCARD receive is needed.
             */
            free(ofi_req);
            return OMPI_ERR_OUT_OF_RESOURCE;
        }

        /* The message now owns the request through req_ptr. */
        (*message)->comm = comm;
        (*message)->req_ptr = ofi_req;
        (*message)->peer = ofi_req->status.MPI_SOURCE;
        (*message)->count = ofi_req->status._ucount;

    } else {
        (*message) = MPI_MESSAGE_NULL;
        free(ofi_req);
    }

    return OMPI_SUCCESS;
}

/**
 * Attempt to cancel an outstanding MTL request.
 *
 * Send requests are left untouched (cancelling sends is not supported yet).
 * A receive request is cancelled through fi_cancel() only if it has not
 * been started; the _cancelled flag in the associated request status
 * is reset to false when the cancellation cannot be carried out.
 *
 * @param mtl          MTL module (unused)
 * @param mtl_request  request to cancel; treated as an ompi_mtl_ofi_request_t
 * @param flag         unused
 *
 * @return OMPI_SUCCESS for send and receive requests, OMPI_ERROR for any
 *         other request type.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_cancel(struct mca_mtl_base_module_t *mtl,
                    mca_mtl_request_t *mtl_request,
                    int flag)
{
    ompi_mtl_ofi_request_t *ofi_req = (ompi_mtl_ofi_request_t*) mtl_request;
    bool cancel_not_possible;
    int ret;

    if (OMPI_MTL_OFI_SEND == ofi_req->type) {
        /* Cannot cancel sends yet */
        return OMPI_SUCCESS;
    }

    if (OMPI_MTL_OFI_RECV != ofi_req->type) {
        return OMPI_ERROR;
    }

    /**
     * Cancel a receive request only if it hasn't been matched yet.
     * The event queue needs to be drained to make sure there isn't
     * any pending receive completion event.
     */
    ompi_mtl_ofi_progress();

    if (ofi_req->req_started) {
        return OMPI_SUCCESS;
    }

    ret = fi_cancel((fid_t)ompi_mtl_ofi.ep, &ofi_req->ctx);
    cancel_not_possible = (0 != ret);

    if (!cancel_not_possible) {
        /* Wait for the request to be cancelled. */
        while (!ofi_req->super.ompi_req->req_status._cancelled) {
            opal_progress();
            if (ofi_req->req_started) {
                /* The receive started while we were waiting. */
                cancel_not_possible = true;
                break;
            }
        }
    }

    if (cancel_not_possible) {
        /* Could not cancel the request. */
        ofi_req->super.ompi_req->req_status._cancelled = false;
    }

    return OMPI_SUCCESS;
}

/**
 * Communicator-creation hook; this implementation performs no work.
 *
 * @param mtl   MTL module (unused)
 * @param comm  communicator being added (unused)
 *
 * @return OMPI_SUCCESS in all cases.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_add_comm(struct mca_mtl_base_module_t *mtl,
                      struct ompi_communicator_t *comm)
{
    (void)mtl;
    (void)comm;

    return OMPI_SUCCESS;
}

/**
 * Communicator-destruction hook; this implementation performs no work.
 *
 * @param mtl   MTL module (unused)
 * @param comm  communicator being removed (unused)
 *
 * @return OMPI_SUCCESS in all cases.
 */
__opal_attribute_always_inline__ static inline int
ompi_mtl_ofi_del_comm(struct mca_mtl_base_module_t *mtl,
                      struct ompi_communicator_t *comm)
{
    (void)mtl;
    (void)comm;

    return OMPI_SUCCESS;
}

END_C_DECLS

#endif  /* MTL_OFI_H_HAS_BEEN_INCLUDED */
