function [A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag)
%[A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag) - canonical correlation
%
%  A, B: transform matrices to apply to x and y respectively
%  R: score of each canonical component (leading eigenvalues minus one)
%
%  x, y: data matrices (observations in rows, channels in columns)
%  shifts: lags handed to nt_cov_lags [default: 0]
%          (sign convention is defined by nt_cov_lags - TODO confirm there)
%  C: covariance matrix of [x, y]; may be 3-D (one covariance per page)
%  m: number of columns of x, i.e. size of the x block within C
%  thresh: discard PCs whose relative power is below this [default: 10^-12]
%  demeanflag: if true, remove means before computing covariance [default: 1]
%
%  Usage 1 (data):
%   [A,B,R]=nt_cca(x,y);           % CCA of x, y
%
%  Usage 2 (data, several lags):
%   [A,B,R]=nt_cca(x,y,shifts);    % CCA for each value of shifts
%
%  Usage 3 (covariance):
%   C=[x,y]'*[x,y];
%   [A,B,R]=nt_cca([],[],[],C,size(x,2));
%   In this form x, y and shifts MUST be empty. The means are whatever they
%   were when C was computed (demeanflag is not used).
%
%  If C is 3-D, CCA is computed independently for each page and the results
%  are stacked along the last dimension of A, B and R.
%
%  When called with data, several shifts and no output argument, R is plotted.
%
%  Note: the scaling of A and B differs from canoncorr (the comparison
%  snippet at the end of this file rescales by sqrt(number of rows)).

nt_greetings;

% ---- defaults -----------------------------------------------------------
if nargin<7 || isempty(demeanflag); demeanflag=1; end
if nargin<6 || isempty(thresh); thresh=10.^-12; end

% ---- data form: build the covariance, then recurse on it ---------------
if nargin>=1 && ~isempty(x)
    if nargin<2; error('nt_cca: y must be provided together with x'); end
    if nargin<3 || isempty(shifts); shifts=0; end
    if numel(shifts)==1 && shifts==0 && isnumeric(x) && ndims(x)==2
        % single zero lag on plain matrices: covariance computed directly
        if demeanflag
            x=nt_demean(x);
            y=nt_demean(y);
        end
        C=[x,y]'*[x,y];
        m=size(x,2);
    else
        % lagged and/or non-matrix input is delegated to nt_cov_lags
        [C,~,m]=nt_cov_lags(x,y,shifts,demeanflag);
    end
    [A,B,R]=nt_cca([],[],[],C,m,thresh);

    if nargout==0 && length(shifts)>1
        figure(1); clf;
        plot(R'); title('correlation for each CC'); xlabel('lag'); ylabel('correlation');
    end
    return
end

% ---- covariance form: argument checks -----------------------------------
if nargin<4 || isempty(C); error('nt_cca: covariance matrix C is required when x is empty'); end
if nargin<5 || isempty(m); error('nt_cca: m (number of columns of x) is required with C'); end
if size(C,1)~=size(C,2); error('nt_cca: C must be square'); end
if ~isempty(x) || ~isempty(y) || ~isempty(shifts)
    error('nt_cca: x, y and shifts must be empty when C is provided');
end
if ndims(C)>3; error('nt_cca: C must be 2-D or 3-D'); end

% ---- 3-D covariance: one CCA per page -----------------------------------
if ndims(C)==3
    nPages=size(C,3);
    N=min(m,size(C,1)-m);
    A=zeros(m,N,nPages);
    B=zeros(size(C,1)-m,N,nPages);
    R=zeros(N,nPages);
    for k=1:nPages
        % forward thresh so the caller's setting applies to every page
        [AA,BB,RR]=nt_cca([],[],[],C(:,:,k),m,thresh);
        % a page may yield fewer components than N; the rest stays zero
        A(1:size(AA,1),1:size(AA,2),k)=AA;
        B(1:size(BB,1),1:size(BB,2),k)=BB;
        R(1:numel(RR),k)=RR;
    end
    return
end

% ---- 2-D covariance: the actual CCA -------------------------------------

% sphere x and y separately (PCA followed by normalization)
A1=nt_cca_whiten(C(1:m,1:m),thresh);
A2=nt_cca_whiten(C(m+1:end,m+1:end),thresh);
n1=size(A1,2);
n2=size(A2,2);

% apply both sphering matrices to the joint covariance (block-diagonal AA)
AA=zeros( size(A1,1)+size(A2,1), n1+n2 );
AA( 1:size(A1,1), 1:n1 )=A1;
AA( size(A1,1)+1:end, n1+1:end )=A2;
C= AA' * C * AA;

N=min(n1,n2); % number of canonical correlates

% PCA of the sphered joint covariance
[V, S] = eig(C) ;
V=real(V); S=real(S);
[E, idx] = sort(diag(S)', 'descend') ;
topcs = V(:,idx);

% split each leading PC into its x part and its y part
A=A1*topcs(1:n1,1:N)*sqrt(2);
B=A2*topcs(n1+1:end,1:N)*sqrt(2);
R=E(1:N)-1;

% Why does it work?
% If x and y were uncorrelated, eigenvalues E would be all ones.
% Correlated dimensions (the canonical correlates) should give values E>1,
% i.e. they should map to the first PCs.
% To obtain CCs we just select the first N PCs.


% ---- test / example code (never executed) -------------------------------

if 0
    % basic example: y shares dimensions with x at various noise levels
    clear
    x=randn(10000,20);
    y=randn(10000,8);
    y(:,1:2)=x(:,1:2);
    y(:,3:4)=x(:,3:4)+randn(10000,2);
    y(:,5:6)=x(:,5:6)+randn(10000,2)*3;
    y(:,7:8)=randn(10000,2);
    [A,B,R]=nt_cca(x,y);
    figure(1); clf
    subplot 321; imagesc(A); title('A');
    subplot 322; imagesc(B); title('B');
    subplot 323; plot(R, '.-'); title('R')
    subplot 324; nt_imagescc((x*A)'*(x*A)); title ('covariance of x*A');
    subplot 325; nt_imagescc((y*B)'*(y*B)); title ('covariance of y*B');
    subplot 326; nt_imagescc([x*A,y*B]'*[x*A,y*B]); title ('covariance of [x*A,y*B]');
end

if 0
    % compare with canoncorr
    clear
    x=randn(1000,11);
    y=randn(1000,9);
    x=x-repmat(mean(x),size(x,1),1);
    y=y-repmat(mean(y),size(y,1),1);
    [A1,B1,R1]=canoncorr(x,y);
    [A2,B2,R2]=nt_cca(x,y);
    A2=A2*sqrt(size(x,1)); % rescale to match canoncorr
    B2=B2*sqrt(size(y,1));
    figure(1); clf;
    subplot 211;
    plot([R1' R2']); title('R'); legend({'canoncorr', 'nt_cca'}, 'Interpreter','none');
    if mean(A1(:,1).*A2(:,1))<0; A2=-A2; end
    subplot 212;
    plot(([x*A1(:,1),x*A2(:,1)])); title('first component'); legend({'canoncorr', 'nt_cca'}, 'Interpreter','none');
    figure(2); clf;set(gcf,'defaulttextinterpreter','none')
    subplot 121;
    nt_imagescc([x*A1,y*B1]'*[x*A1,y*B1]); title('canoncorr');
    subplot 122;
    nt_imagescc([x*A2,y*B2]'*[x*A2,y*B2]); title('nt_cca');
end

if 0
    % time
    x=randn(100000,100);
    tic;
    [A,B,R]=nt_cca(x,x);
    disp('nt_cca time: ');
    toc
    tic; % restart the timer so canoncorr is timed on its own
    [A,B,R]=canoncorr(x,x);
    disp('canoncorr time: ');
    toc
end

if 0
    % shifts
    x=randn(1000,10);
    y=randn(1000,10);
    y(:,1:3)=x(:,1:3);
    shifts=-10:10;
    [A1,B1,R1]=nt_cca(x,y,shifts);
    figure(1); clf
    plot(shifts,R1'); xlabel('lag'); ylabel('R');
end

if 0
    % y is a column permutation of x
    x=randn(1000,10);
    y=randn(1000,10); y=x(:,randperm(10));
    [A1,B1,R1]=nt_cca(x,y);
    figure(1); clf
    nt_imagescc([x*A1,y*B1]'*[x*A1,y*B1]);
end

if 0
    % cell array input
    x=randn(1000,10);
    y=randn(1000,10);
    xx={x,x,x}; yy={x,y,y};
    [A,B,R]=nt_cca(xx,yy);
    disp('seems to work...');
end


function W=nt_cca_whiten(Cblock,thresh)
%W=nt_cca_whiten(Cblock,thresh) - sphering matrix from a covariance block
%
%  W: matrix such that W'*Cblock*W is (approximately) identity
%
%  Cblock: square covariance matrix
%  thresh: PCs with eigenvalue/max(eigenvalue) <= thresh are discarded

[V, S] = eig(Cblock) ;
V=real(V); S=real(S);
[E, idx] = sort(diag(S)', 'descend') ;
keep=find(E/max(E)>thresh);
topcs = V(:,idx(keep));
E = E (keep);
% Exponent marginally below 1 perturbs the normalization very slightly.
% NOTE(review): presumably breaks the symmetry when x and y are perfectly
% correlated - confirm against the upstream source.
EXP=1-10^-12;
E=E.^EXP;
W=topcs*diag(sqrt((1./E)));