constraint_app
|
00001 #include "mex.h" 00002 #include "ConstraintApp.h" 00003 00017 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] ) 00018 { 00019 if ( nrhs != 3 ) { 00020 mexErrMsgTxt("requires exactly 3 parameters: (ptr, states, ids)"); 00021 } 00022 00023 if ( nlhs != 1 ) { 00024 mexErrMsgTxt("returns exactly 1 parameters"); 00025 } 00026 00027 ConstraintApp* app(*(ConstraintApp**)mxGetData(prhs[0])); 00028 00029 int stateSize = app->GetStateSize(); 00030 int stateCount = mxGetN(prhs[1]); 00031 if ( mxGetM(prhs[1]) != stateSize ) { 00032 mexErrMsgTxt("state must be SxN, where S is the size of the state"); 00033 } 00034 00035 int idSize = mxGetNumberOfElements(prhs[2]); 00036 00037 std::vector<std::string> ids(idSize); 00038 const mxArray* idCellArray = prhs[2]; 00039 for ( int i = 0; i < idSize; i++ ) { 00040 const mxArray* string_array = mxGetCell(idCellArray, i); 00041 int buflen = mxGetNumberOfElements(string_array) + 1; 00042 char* str = new char[buflen]; 00043 mxGetString(string_array, str, buflen); 00044 //mexPrintf("str: [%s] %i\n", str, buflen); 00045 ids[i] = std::string(str); 00046 delete str; 00047 } 00048 00049 int numObs = ids.size()*3; 00050 plhs[0] = mxCreateDoubleMatrix(numObs, stateCount, mxREAL); 00051 double* o((double*)mxGetData(plhs[0])); 00052 00053 double* s((double*)mxGetData(prhs[1])); 00054 for ( int i = 0; i < stateCount; i++ ) { 00055 std::vector<double> state(stateSize), obs; 00056 std::copy(s+i*stateSize, s+i*stateSize+stateSize, state.begin()); 00057 if ( !app->GetExpectedObservations(state, ids, obs) ) { 00058 mexErrMsgTxt("unable to find a link"); 00059 } 00060 std::copy(obs.begin(), obs.end(), o+numObs*i); 00061 } 00062 }