00001 #include "MPISynchroShuffleAllToAllComm.h"
00002 #include <algorithm>
00003
00004 using std::cerr;
00005 using std::copy;
00006
00007 MPISynchroShuffleAllToAllComm::MPISynchroShuffleAllToAllComm(MPIInputBufferVector & mpiInputBuffers,
00008 MPIOutputBufferVector & mpiOutputBuffers,
00009 MPI::Intracomm & comm,
00010 vector<bool> &incomingConnections,
00011 vector<bool> &outgoingConnections) :
00012 MPIAllToAllCommunicator(mpiInputBuffers, mpiOutputBuffers, comm,
00013 incomingConnections, outgoingConnections)
00014
00015 {
00016 needsToSend.resize(numNodes);
00017 needsToReceive.resize(numNodes);
00018 }
00019
00020 MPISynchroShuffleAllToAllComm::~MPISynchroShuffleAllToAllComm()
00021 {}
00022
00023 void MPISynchroShuffleAllToAllComm::prepare()
00024 {
00025 copy(outgoing_connections.begin(), outgoing_connections.end(), needsToSend.begin());
00026 copy(outgoing_connections.begin(), outgoing_connections.end(), hasNextToSend.begin());
00027 copy(incoming_connections.begin(), incoming_connections.end(), needsToReceive.begin());
00028 }
00029
00030 void MPISynchroShuffleAllToAllComm::doExchangeAlgorithm()
00031 {
00032
00033 for (int i = 0 ; i < numNodes ; ++i) {
00034 outputBuffers[i].setFinishedFlag(!hasNextToSend[i]);
00035 }
00036
00037 MPIExchangeBlocksInfo & input_blocks = inputBuffers.getMPIExchangeBlocksInfo();
00038
00039 MPIExchangeBlocksInfo & output_blocks = outputBuffers.getMPIExchangeBlocksInfo();
00040
00041 int myid = mpi_comm.Get_rank();
00042 int tag = 100;
00043
00044
00045 for (int i = 1; i < numNodes; ++i) {
00046 int to = (myid+i) % numNodes;
00047 int from = (myid+numNodes-i) % numNodes;
00048 int order;
00049 if (i % 2 == 0) {
00050 order = ( myid + ((myid - myid % i) / i) % 2 ) % 2 ;
00051 } else {
00052 order = myid % 2;
00053 }
00054 if (order) {
00055 if (needsToSend[to]) {
00056 mpi_comm.Send(output_blocks.buffers[to],output_blocks.counts[to], output_blocks.datatypes[to], to, tag);
00057 }
00058 if (needsToReceive[from]) {
00059 mpi_comm.Recv(input_blocks.buffers[from], input_blocks.counts[from], input_blocks.datatypes[from], from, MPI::ANY_TAG);
00060 }
00061 } else {
00062 if (needsToReceive[from]) {
00063 mpi_comm.Recv(input_blocks.buffers[from], input_blocks.counts[from], input_blocks.datatypes[from], from, MPI::ANY_TAG);
00064 }
00065 if (needsToSend[to]) {
00066 mpi_comm.Send(output_blocks.buffers[to], output_blocks.counts[to], output_blocks.datatypes[to], to, tag);
00067 }
00068 }
00069 }
00070 }