% Nested cross-validation for SCANT item selection with lasso regression

% Set number of folds for cross-validation
numFold_A = 10; % top-level partition: item selection and held-out test data
numFold_B = 5;  % mid-level partition: subject inclusion variability
numFold_C = 5;  % low-level partition: lasso parameter selection

% Load PNT and WAB data
load wab_1
load pntC_1

% Identify total number of subjects
numSubj = length(wab_1);

% Assign each subject to a cross-validation partition based on nearest-neighbor
% matching of WAB AQ to a quasi-random Sobol sequence
SS = sobolset(1,'Skip',1e3);
SS = scramble(SS,'MatousekAffineOwen');
test_c2 = false(numSubj,numFold_A+1);
wab_idx = sortrows([(1:numSubj)' wab_1],2);
ordSubj = zeros(numSubj,1);
for ii = 1:numSubj
    temp = abs(wab_idx(:,2)-(SS(ii)*100));
    ordSubj(wab_idx(find(temp==min(temp),1),1)) = ii;
    test_c2(wab_idx(find(temp==min(temp),1),1),mod(ii,numFold_A+1)+1) = true;
    wab_idx(find(temp==min(temp),1),2) = inf;
end

% Censor over-represented subjects who were assigned to partitions last
test_c2((ordSubj>220),:) = false;
training_c2 = ~test_c2;
% Subjects assigned to partition 6 are withheld from every training partition
training_c2(test_c2(:,6),:) = false;
training_c2((ordSubj>220),:) = false;

% Define variables for iterative short-form construction and assessment
rmse_pred = zeros(174,numFold_A);
r2_pred = zeros(174,numFold_A);
rmse_predB = zeros(174,numFold_A);
r2_predB = zeros(174,numFold_A);
Bs = zeros(174,numFold_B,numFold_A);
ids = zeros(174,174,numFold_A);

% Iterative short-form construction and assessment
for jj = 1:numFold_A

    % Extract training data from top-level partition
    wab = wab_1(training_c2(:,jj));
    pntC = pntC_1(training_c2(:,jj),:);
    numSubj = length(wab);

    % Generate quasi-random CV partition for subject inclusion variability
    test_c1 = false(numSubj,numFold_B);
    wab_idx = sortrows([(1:numSubj)' wab],2);
    for ii = 1:numSubj
        temp = abs(wab_idx(:,2)-(SS(ii)*100));
        test_c1(wab_idx(find(temp==min(temp),1),1),mod(ii,numFold_B)+1) = true;
        wab_idx(find(temp==min(temp),1),2) = inf;
    end
    training_c1 = ~test_c1;

    % Obtain lasso regression coefficients, balancing model complexity and prediction accuracy
    for ii = 1:numFold_B
        [B,FitInfo] = lasso(pntC(training_c1(:,ii),:),wab(training_c1(:,ii)),'CV',numFold_C,...
            'Standardize',false,'Options',statset('UseParallel',true));
        Bs(:,ii,jj) = B(:,FitInfo.Index1SE);
    end

    % Obtain mean lasso regression coefficients over subject inclusion variability
    temp = sortrows([(1:174)' mean(Bs(:,:,jj),2)],2);

    % Predict WAB AQ with increasing numbers of items in the naming test
    for numItem = 1:174

        % Indicate included items (the numItem items with the largest mean coefficients)
        ids(numItem,temp(175-numItem:174,1),jj) = 1;

        % Fit linear regression predicting WAB AQ from the summed score on included naming items
        stats = regstats(wab,sum(pntC(:,logical(ids(numItem,:,jj))),2),'linear',{'beta'});

        % Calculate cross-validated predictions of WAB AQ for the held-out top-level fold
        wab_hat = [ones(sum(test_c2(:,jj)),1) sum(pntC_1(test_c2(:,jj),logical(ids(numItem,:,jj))),2)]*stats.beta;

        % Calculate prediction error
        rmse_pred(numItem,jj) = sqrt(mean((wab_1(test_c2(:,jj))-wab_hat).^2));
        r2_pred(numItem,jj) = corr(wab_1(test_c2(:,jj)),wab_hat)^2;
    end

    % Plot prediction error by number of included naming items
    figure(1);
    hold on;
    plot(1:174,rmse_pred(:,jj));
    drawnow;
end

% Plot average prediction error over cross-validation folds by number of included naming items
figure(1);
set(gca,'FontSize',12);
plot(1:174,mean(rmse_pred,2),'k-','LineWidth',2);
xlabel('Naming Item Set Size','FontSize',16);
ylabel('WAB AQ Prediction Error (RMSE)','FontSize',16);
legend({'Fold 1','Fold 2','Fold 3','Fold 4','Fold 5','Fold 6','Fold 7','Fold 8','Fold 9',...
    'Fold 10','Average'},'FontSize',12);
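% --- Illustrative sketch (not part of the original selection procedure) ---
% One way to summarize the averaged error curve is to find the item-set size
% with the lowest mean cross-validated RMSE; the variable names below
% (meanRMSE, bestNumItem) are hypothetical additions.
meanRMSE = mean(rmse_pred,2);
[~,bestNumItem] = min(meanRMSE);
fprintf('Lowest mean cross-validated RMSE at %d items\n',bestNumItem);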
% Index items by average lasso regression coefficients across all cross-validation folds
idx = sortrows([(1:174)' mean(mean(Bs(:,:,:),2),3)],2);

% Index SCANT items as the set of 20 PNT items with the highest coefficients
SCANT_idx = idx(155:174,1);

% save rmse_pred rmse_pred
% save r2_pred r2_pred
% save Bs Bs
% save test_c1 test_c1
% save training_c1 training_c1
% save training_c2 training_c2
% save test_c2 test_c2
% save ids ids
% save ordSubj ordSubj
% save idx idx
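% --- Illustrative usage sketch (not part of the original script) ---
% Assuming pntC_1 holds per-item naming accuracy (subjects x 174 items) and
% wab_1 holds WAB AQ, the 20-item SCANT total can be mapped to a predicted AQ
% with a simple linear fit on the full sample; the variable names below
% (scant_total, scant_fit, wab_hat_scant) are hypothetical additions.
scant_total = sum(pntC_1(:,SCANT_idx),2);                           % SCANT total score per subject
scant_fit = regstats(wab_1,scant_total,'linear',{'beta'});          % linear map from SCANT score to AQ
wab_hat_scant = [ones(length(wab_1),1) scant_total]*scant_fit.beta; % predicted WAB AQ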