00001 #include "MPIInputBuffer.h"
00002
00003 #include "PCSIMException.h"
00004
00005 #include <cassert>
00006
00007 #include <algorithm>
00008
00009 #include <cstring>
00010
00011 using std::min;
00012
00013
00014 MPIInputBuffer::MPIInputBuffer(int nEngines)
00015 : slicer(MPIBufferSlicer::tpInputBuffer)
00016 {
00017
00018 analogMsgCounters.resize(nEngines, 0);
00019 initialized = false;
00020 }
00021
00022 MPIInputBuffer::~MPIInputBuffer()
00023 {
00024 if (initialized)
00025 if (slicer.thereIsMixedDataType)
00026 mixedMPIDataType.Free();
00027 }
00028
00029 void MPIInputBuffer::initialize(MPIMessageSpec msgSpec, size_t spikeBufferSize, size_t maxMPIMsgSize,
00030 void *baseBufferPtr, void *analogBuffer, void *spikeBuffer)
00031 {
00032 this->baseBufferPtr = baseBufferPtr;
00033 currentMsgInfo = msgSpec;
00034 mpiInputSpikeBuffer.initialize((MPIInputSpikeBuffer<>::coding_element_type *)spikeBuffer);
00035
00036 analog_buf = analogBuffer;
00037 spike_buf = spikeBuffer;
00038
00039 size_t analog_buffer_size_elements = (double *)spikeBuffer - (double *)analogBuffer;
00040
00041 spike_buffer_size_elements = spikeBufferSize / sizeof(MPIInputSpikeBuffer<>::coding_element_type);
00042
00043 slicer.initialize(analog_buffer_size_elements,
00044 maxMPIMsgSize, spikeBufferSize);
00045
00046 analogMPIDatatype = MPI::DOUBLE;
00047 if (sizeof(MPIInputSpikeBuffer<>::coding_element_type) == 4)
00048 spikingMPIDatatype = MPI::LONG;
00049 else
00050 spikingMPIDatatype = MPI::SHORT;
00051
00052
00053 if (slicer.thereIsMixedDataType) {
00054
00055 MPI::Datatype mixed_data_types[2];
00056 mixed_data_types[0] = analogMPIDatatype;
00057 mixed_data_types[1] = spikingMPIDatatype;
00058
00059 if (maxMPIMsgSize == 0)
00060 mixedCounts[0] = analog_buffer_size_elements;
00061 else
00062 mixedCounts[0] = analog_buffer_size_elements % (maxMPIMsgSize / sizeof(double));
00063 mixedCounts[1] = spikeBufferSize / sizeof(MPIInputSpikeBuffer<>::coding_element_type);
00064
00065 mixedDisplacements[0] = 0;
00066 mixedDisplacements[1] = mixedCounts[0] * sizeof(double);
00067
00068 mixedMPIDataType = MPI::Datatype::Create_struct(2, mixedCounts, mixedDisplacements, mixed_data_types);
00069
00070 mixedMPIDataType.Commit();
00071 mixedMsgAbsoluteDisplacement = (char *)( (double *)spike_buf - mixedCounts[0] ) - (char *)baseBufferPtr;
00072 }
00073 currentMsgInfo.hasContent = true;
00074 initialized = true;
00075 }
00076
00077
00078 void MPIInputBuffer::startNewMPIExchange()
00079 {
00080 slicer.reset();
00081 }
00082
00083 bool MPIInputBuffer::hasNextBufferSlice()
00084 {
00085 if (slicer.currentSliceType == MPIBufferSlicer::sliceAnalog ||
00086 slicer.currentSliceType == MPIBufferSlicer::sliceUndefined ) {
00087 if (totalAnalogMsgCounter)
00088 return true;
00089 else
00090 return !mpiInputSpikeBuffer.isLastReceivedBuffer();
00091 }
00092 return !mpiInputSpikeBuffer.isLastReceivedBuffer();
00093 }
00094
00095 MPIMessageSpec & MPIInputBuffer::prepareNextBufferSlice()
00096 {
00097 slicer.calcNextBufferSliceDimensions();
00098
00099
00100 switch (slicer.currentSliceType) {
00101 case MPIBufferSlicer::sliceAnalog :
00102 *currentMsgInfo.displacement = ((char *)analog_buf + slicer.currentSlicePos) - (char *)baseBufferPtr;
00103 *currentMsgInfo.buffer = (char *)analog_buf + slicer.currentSlicePos;
00104 *currentMsgInfo.count = slicer.currentAnalogSliceSize / sizeof(double);
00105 *currentMsgInfo.datatype = analogMPIDatatype;
00106 currentMsgInfo.content_type = MPIMessageSpec::contentAnalog;
00107 break;
00108 case MPIBufferSlicer::sliceMixed :
00109 *currentMsgInfo.displacement = mixedMsgAbsoluteDisplacement;
00110 *currentMsgInfo.buffer = (char *)baseBufferPtr + mixedMsgAbsoluteDisplacement;
00111 *currentMsgInfo.count = 1;
00112 *currentMsgInfo.datatype = mixedMPIDataType;
00113 currentMsgInfo.content_type = MPIMessageSpec::contentMixed;
00114 break;
00115 case MPIBufferSlicer::sliceSpiking :
00116 *currentMsgInfo.displacement = (char *)spike_buf - (char *)baseBufferPtr;
00117 *currentMsgInfo.buffer = spike_buf;
00118 *currentMsgInfo.count = spike_buffer_size_elements;
00119 *currentMsgInfo.datatype = spikingMPIDatatype;
00120 currentMsgInfo.content_type = MPIMessageSpec::contentSpiking;
00121 break;
00122 case MPIBufferSlicer::sliceUndefined :
00123 assert( 0 );
00124 break;
00125 }
00126 return currentMsgInfo;
00127 }
00128
00129
00130 MPIInputBufferVector::MPIInputBufferVector(vector< vector< gl_engineid_t> > & glengineids, int numBuffers)
00131 : initialized(false), nNodes(numBuffers), mpiExchBlocksInfo(numBuffers)
00132 {
00133 for (int i = 0; i < nNodes ; ++i) {
00134 _buffers.push_back(MPIInputBuffer(glengineids[i].size()));
00135 }
00136 }
00137
00138
00139 void MPIInputBufferVector::initialize(int minDelay, size_t maxMPIMessageSize, size_t spikeBufferSize)
00140 {
00141
00142 max_mpi_msg_size = maxMPIMessageSize;
00143 spike_buffer_size = spikeBufferSize;
00144
00145 int pool_size = 0;
00146 for (int i = 0; i < nNodes; ++i) {
00147 _buffers[i].calculateTotalAnalogMsgCounter();
00148
00149 pool_size += _buffers[i].totalAnalogMsgCounter * sizeof(double) * minDelay + spike_buffer_size;
00150 }
00151
00152 memoryPool = new char[pool_size];
00153
00154 memset(memoryPool, 0, pool_size);
00155
00156
00157 void * analogBufPtr = (void *)memoryPool;
00158 for (int i = 0; i < nNodes; ++i) {
00159 void *spikeBufPtr = (void *)((char *)analogBufPtr + _buffers[i].totalAnalogMsgCounter * sizeof(double) * minDelay);
00160 _buffers[i].initialize(mpiExchBlocksInfo.getMsgSpec(i), spike_buffer_size, max_mpi_msg_size, memoryPool, analogBufPtr, spikeBufPtr);
00161 analogBufPtr = (char *)analogBufPtr + _buffers[i].totalAnalogMsgCounter * sizeof(double) * minDelay + spike_buffer_size;
00162 }
00163
00164 initialized = true;
00165 }
00166
00167
00168 MPIExchangeBlocksInfo & MPIInputBufferVector::getMPIExchangeBlocksInfo()
00169 {
00170 return mpiExchBlocksInfo;
00171 }
00172