Home > NoiseTools > nt_cca.m

nt_cca

PURPOSE ^

[A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag) - canonical correlation

SYNOPSIS ^

function [A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag)

DESCRIPTION ^

[A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag) - canonical correlation

  A, B: transform matrices
  R: canonical correlations, one per pair of canonical components

  x,y: column matrices
  shifts: positive lag means y delayed relative to x
  C: covariance matrix of [x, y]
  m: number of columns of x
  thresh: discard PCs whose eigenvalue, relative to the largest, is below this [default: 10^-12]
  demeanflag: if true remove means [default: true]

  Usage 1:
   [A,B,R]=nt_cca(x,y); % CCA of x, y

  Usage 2: 
   [A,B,R]=nt_cca(x,y,shifts); % CCA of x, y for each value of shifts.
   A positive shift indicates that y is delayed relative to x.

  Usage 3:
   C=[x,y]'*[x,y]; % covariance
   [A,B,R]=nt_cca([],[],[],C,size(x,2)); % CCA of x,y

 Use the third form to handle multiple files or large data
 (covariance C can be calculated chunk-by-chunk). 

 C can be 3-D, in which case CCA is derived independently from each page.

 Warning: in Usage 3 the means of x and y are NOT removed (C is used as given); Usages 1 and 2 remove them unless demeanflag is false.
 Warning: A, B scaled so that (x*A)'*(x*A) and (y*B)'*(y*B) are identity matrices (differs from canoncorr).

 See nt_cov_lags, nt_relshift, nt_cov, nt_pca.

 NoiseTools.

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 function [A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag)
0002 %[A,B,R]=nt_cca(x,y,shifts,C,m,thresh,demeanflag) - canonical correlation
0003 %
0004 %  A, B: transform matrices (canonical components are x*A and y*B)
0005 %  R: canonical correlations, one per component pair (one column per page if C is 3-D)
0006 %
0007 %  x,y: column matrices (non-numeric input, e.g. cell arrays, is passed on to nt_cov_lags)
0008 %  shifts: positive lag means y delayed relative to x [default: 0]
0009 %  C: covariance matrix of [x, y]
0010 %  m: number of columns of x
0011 %  thresh: discard PCs whose eigenvalue relative to the largest is below this [default: 10^-12]
0012 %  demeanflag: if true remove means [default: true]
0013 %
0014 %  Usage 1:
0015 %   [A,B,R]=nt_cca(x,y); % CCA of x, y
0016 %
0017 %  Usage 2:
0018 %   [A,B,R]=nt_cca(x,y,shifts); % CCA of x, y for each value of shifts.
0019 %   A positive shift indicates that y is delayed relative to x.
0020 %
0021 %  Usage 3:
0022 %   C=[x,y]'*[x,y]; % covariance
0023 %   [A,B,R]=nt_cca([],[],[],C,size(x,2)); % CCA of x,y
0024 %
0025 % Use the third form to handle multiple files or large data
0026 % (covariance C can be calculated chunk-by-chunk).
0027 %
0028 % C can be 3-D, in which case CCA is derived independently from each page.
0029 %
0030 % Warning: in Usage 3 the means of x and y are NOT removed (C is used as given); Usages 1 and 2 remove them unless demeanflag is false.
0031 % Warning: A, B scaled so that (x*A)'*(x*A) and (y*B)'*(y*B) are identity matrices (differs from canoncorr).
0032 %
0033 % See nt_cov_lags, nt_relshift, nt_cov, nt_pca.
0034 %
0035 % NoiseTools.
0036
0037 nt_greetings; 
0038
0039 if nargin<7||isempty(demeanflag); demeanflag=1; end % default: remove means
0040
0041 if nargin<6||isempty(thresh) % default threshold for discarding weak PCs
0042     thresh=10.^-12; 
0043 end
0044
0045 if exist('x','var') && ~isempty(x)
0046     % Calculate covariance of [x,y]
0047     if ~exist('y','var')||isempty(y); error('nt_cca: y must be given together with x'); end
0048     if ~exist('shifts','var')||isempty(shifts); shifts=[0]; end
0049     if numel(shifts)==1 && shifts==0 && isnumeric(x) && ndims(x)==2; 
0050         if demeanflag
0051             x=nt_demean(x);
0052             y=nt_demean(y);
0053         end
0054         C=[x,y]'*[x,y]; % simple case
0055         m=size(x,2); 
0056     else        
0057         [C,~,m]=nt_cov_lags(x,y,shifts,demeanflag); % lags, multiple trials, etc.
0058     end
0059     [A,B,R]=nt_cca([],[],[],C,m,thresh);
0060     
0061     if nargout==0 
0062         % plot something nice
0063         if length(shifts)>1;
0064             figure(1); clf;
0065             plot(shifts,R'); title('correlation for each CC'); xlabel('lag'); ylabel('correlation');
0066         end
0067      end
0068     return
0069 end % else keep going
0070
0071 if ~exist('C','var') || isempty(C) ; error('nt_cca: need either x and y, or covariance C'); end
0072 if ~exist('m','var') || isempty(m); error('nt_cca: m (number of columns of x) is required with C'); end
0073 if size(C,1)~=size(C,2); error('nt_cca: C must be square'); end
0074 if ~isempty(x) || ~isempty(y) || ~isempty(shifts)  ; error('nt_cca: x, y and shifts must be empty when C is given'); end
0075 if ndims(C)>3; error('nt_cca: C must be 2-D or 3-D'); end
0076
0077 if ndims(C) == 3
0078     % covariance is 3D: do a separate CCA for each page
0079     N=min(m,size(C,1)-m); % note that for some pages there may be fewer than N CCs
0080     A=zeros(m,N,size(C,3));
0081     B=zeros(size(C,1)-m,N,size(C,3));
0082     R=zeros(N,size(C,3));
0083     for k=1:size(C,3);
0084         [AA,BB,RR]=nt_cca([],[],[],C(:,:,k),m,thresh); % same threshold for every page
0085         A(1:size(AA,1),1:size(AA,2),k)=AA;
0086         B(1:size(BB,1),1:size(BB,2),k)=BB;
0087         R(1:size(RR,2),k)=RR;
0088     end
0089     return;
0090 end % else keep going
0091
0092
0093 %%
0094 % Calculate CCA given C=[x,y]'*[x,y] and m=size(x,2);
0095
0096 % sphere x
0097 Cx=C(1:m,1:m);
0098 [V, S] = eig(Cx) ;  
0099 V=real(V); S=real(S);
0100 [E, idx] = sort(diag(S)', 'descend') ;
0101 keep=find(E/max(E)>thresh); % discard near-degenerate directions
0102 topcs = V(:,idx(keep));
0103 E = E (keep);
0104 EXP=1-10^-12; 
0105 E=E.^EXP; % break symmetry when x and y perfectly correlated (otherwise cols of x*A and y*B are not orthogonal)
0106 A1=topcs*diag(sqrt((1./E))); % sphering matrix for x
0107
0108 % sphere y
0109 Cy=C(m+1:end,m+1:end);
0110 [V, S] = eig(Cy) ;  
0111 V=real(V); S=real(S);
0112 [E, idx] = sort(diag(S)', 'descend') ;
0113 keep=find(E/max(E)>thresh); % discard near-degenerate directions
0114 topcs = V(:,idx(keep));
0115 E = E (keep);
0116 E=E.^EXP; %
0117 A2=topcs*diag(sqrt((1./E))); % sphering matrix for y
0118
0119 % apply sphering matrices to C
0120 AA=zeros( size(A1,1)+size(A2,1), size(A1,2)+size(A2,2) );
0121 AA( 1:size(A1,1), 1:size(A1,2) )=A1;
0122 AA( size(A1,1)+1:end, size(A1,2)+1:end )=A2;
0123 C= AA' * C * AA; % joint covariance of sphered [x,y]
0124
0125 N=min(size(A1,2),size(A2,2)); % number of canonical components
0126
0127 % PCA
0128 [V, S] = eig(C) ;
0129 %[V, S] = eigs(C,N) ; % not faster
0130 V=real(V); S=real(S);
0131 [E, idx] = sort(diag(S)', 'descend') ;
0132 topcs = V(:,idx);
0133
0134 A=A1*topcs(1:size(A1,2),1:N)*sqrt(2);  % x-half of each top eigenvector has norm 1/sqrt(2) (nonzero correlation): rescale to unit norm
0135 B=A2*topcs(size(A1,2)+1:end,1:N)*sqrt(2); % same for the y-half
0136 R=E(1:N)-1; % top eigenvalues are 1+r, r = canonical correlation
0137
0138
0139 %{
0140 Why does it work?
0141 If x and y were uncorrelated, eigenvalues E would be all ones. 
0142 Correlated dimensions (the canonical correlates) should give values E>1, 
0143 i.e. they should map to the first PCs. 
0144 To obtain CCs we just select the first N PCs. 
0145 %}
0146
0147 %%
0148
0149 %%
0150 % test code
0151 if 0
0152     % basic
0153     clear
0154     x=randn(10000,20);
0155     y=randn(10000,8);
0156     y(:,1:2)=x(:,1:2); % perfectly correlated
0157     y(:,3:4)=x(:,3:4)+randn(10000,2); % 1/2 correlated
0158     y(:,5:6)=x(:,5:6)+randn(10000,2)*3; % 1/4 correlated
0159     y(:,7:8)=randn(10000,2); % uncorrelated
0160     [A,B,R]=nt_cca(x,y);
0161     figure(1); clf
0162     subplot 321; imagesc(A); title('A');
0163     subplot 322; imagesc(B); title('B');
0164     subplot 323; plot(R, '.-'); title('R')
0165     subplot 324; nt_imagescc((x*A)'*(x*A)); title ('covariance of x*A');
0166     subplot 325; nt_imagescc((y*B)'*(y*B)); title ('covariance of y*B');
0167     subplot 326; nt_imagescc([x*A,y*B]'*[x*A,y*B]); title ('covariance of [x*A,y*B]');
0168 end
0169
0170 if 0 
0171     % compare with canoncorr
0172     clear
0173     x=randn(1000,11);
0174     y=randn(1000,9);
0175     x=x-repmat(mean(x),size(x,1),1); % center, otherwise result may differ slightly from canoncorr
0176     y=y-repmat(mean(y),size(y,1),1);
0177     [A1,B1,R1]=canoncorr(x,y);
0178     [A2,B2,R2]=nt_cca(x,y);   
0179     A2=A2*sqrt(size(x,1)); % scale like canoncorr
0180     B2=B2*sqrt(size(y,1));
0181     figure(1); clf; 
0182     subplot 211; 
0183     plot([R1' R2']); title('R'); legend({'canoncorr', 'nt_cca'}, 'Interpreter','none'); 
0184     if mean(A1(:,1).*A2(:,1))<0; A2=-A2; end
0185     subplot 212; 
0186     plot(([x*A1(:,1),x*A2(:,1)])); title('first component'); legend({'canoncorr', 'nt_cca'}, 'Interpreter','none'); 
0187     figure(2); clf;set(gcf,'defaulttextinterpreter','none')
0188     subplot 121; 
0189     nt_imagescc([x*A1,y*B1]'*[x*A1,y*B1]); title('canoncorr'); 
0190     subplot 122; 
0191     nt_imagescc([x*A2,y*B2]'*[x*A2,y*B2]); title('nt_cca');
0192 end
0193
0194 if 0
0195     % time
0196     x=randn(100000,100); 
0197     tic; 
0198     [A,B,R]=nt_cca(x,x); 
0199     disp('nt_cca time: ');
0200     toc    
0201     [A,B,R]=canoncorr(x,x); 
0202     disp('canoncorr time: ');
0203     toc
0204 %     [A,B,R]=cca(x,x);
0205 %     disp('cca time: ');
0206 %     toc
0207 end
0208
0209 if 0
0210     % shifts
0211     x=randn(1000,10);
0212     y=randn(1000,10);
0213     y(:,1:3)=x(:,1:3);
0214     shifts=-10:10;
0215     [A1,B1,R1]=nt_cca(x,y,shifts);
0216     figure(1); clf
0217     plot(shifts,R1'); xlabel('lag'); ylabel('R');
0218 end
0219
0220 if 0
0221     % what happens if x & y perfectly correlated?
0222     x=randn(1000,10);
0223     y=randn(1000,10); y=x(:,randperm(10)); %+0.000001*y;
0224     [A1,B1,R1]=nt_cca(x,y);
0225     figure(1); clf
0226     nt_imagescc([x*A1,y*B1]'*[x*A1,y*B1]);
0227 end    
0228
0229 if 0
0230     % x and y are cell arrays
0231     x=randn(1000,10); 
0232     y=randn(1000,10);
0233     xx={x,x,x};  yy={x,y,y};
0234     [A,B,R]=nt_cca(xx,yy);
0235     disp('seems to work...');
0236 end
0237
0238     
0239

Generated on Tue 18-Feb-2020 11:23:12 by m2html © 2005