%Ekta Gujral,Ravdeep Pasricha and Evangelos E. Papalexakis
%Department of Computer Science and Engineering, University of California Riverside 
%"OCTen: Online Compression-based Tensor Decomposition", submitted to PKDD 2018
function [A,B,C,summaries,memUsed]=OCTen_update(Xnew,summaries_old,ops,U,V)
%OCTEN_UPDATE  Compress one incoming batch into ops.p small cubes, decompose
%each cube, align the replicas, and recover full-size factor matrices.
%
%   [A,B,C,summaries,memUsed] = OCTen_update(Xnew,summaries_old,ops,U,V)
%
%   Inputs
%     Xnew          - incoming batch, treated as an I x J x K array (it is
%                     passed through double() and reshaped to I x (J*K)).
%                     A batch with a single frontal slice (K = 1) is accepted.
%     summaries_old - cell array; summaries_old{pp} is added onto the first
%                     round(ops.Q/7)-1 frontal slices of the new summary pp.
%     ops           - struct with fields
%                       Q      : side length of every compressed cube
%                       p      : number of compressed replicas
%                       shared : number of leading columns of the compression
%                                matrices that are common to all replicas
%                       R      : rank handed to comfac
%     U, V          - compression matrices for modes 1 and 2; U(:,:,pp) is
%                     I x Q and V(:,:,pp) is J x Q.
%                     NOTE(review): the stacking step below only uses columns
%                     ops.shared+1:Q of replicas 2..p, so it presumably relies
%                     on the first ops.shared columns of U and V being equal
%                     across replicas - confirm against the caller.
%
%   Outputs
%     A, B, C       - least-squares factor estimates of size I x R, J x R and
%                     K x R respectively.
%     summaries     - cell array; summaries{pp} is the Q x Q x Q compressed
%                     cube of replica pp (after the carry-over addition).
%     memUsed       - sum of the byte counts reported by WHOS for the function
%                     workspace at the end of the call, with 'summaries'
%                     counted at 1/p of its size.
%
%   Requires comfac and perm2match on the path (not defined in this file).

    % Query each dimension explicitly: size(Xnew) alone returns only two
    % entries when the batch has a single frontal slice, which made sz(3) fail.
    sz=[size(Xnew,1),size(Xnew,2),size(Xnew,3)];
    I=sz(1);J=sz(2);K=sz(3);

    % Fresh random compression matrix for mode 3; the first ops.shared
    % columns are made identical across all replicas.
    W=randn(K,ops.Q,ops.p);
    for pp=2:ops.p
        W(:,1:ops.shared,pp)=W(:,1:ops.shared,1);
    end

    % Takes up more memory but faster:
    parfor pp=1:ops.p
        % Mode-by-mode compression: multiply by U', W', V' in turn, each
        % time reshaping so the next mode becomes the leading dimension.
        G1=reshape(double(Xnew),I,J*K);
        G2=reshape(U(:,:,pp).'*G1,ops.Q*J,K).';
        G3=reshape(W(:,:,pp).'*G2,ops.Q*ops.Q,J).';
        G4=reshape(V(:,:,pp).'*G3,ops.Q*ops.Q,ops.Q).';
        S=reshape(G4,ops.Q,ops.Q,ops.Q);

        % Carry the previous summary over on the leading frontal slices only
        % (slices q with q < round(Q/7)); the remaining slices keep the new
        % values untouched.
        nCarry=round(ops.Q/7)-1;
        if nCarry>0
            Sold=summaries_old{pp};
            S(:,:,1:nCarry)=S(:,:,1:nCarry)+Sold(:,:,1:nCarry);
        end
        summaries{pp}=S;

        % NOTE(review): the meaning of the third comfac argument (2) is not
        % visible from this file.
        [As(:,:,pp),Bs(:,:,pp),Cs(:,:,pp),~]=comfac(S,ops.R,2);
    end
    clear Xnew summaries_old; % not needed anymore (also keeps them out of memUsed)

    % Rows 1:ops.shared of every replica's factors, each column scaled so
    % that its largest-magnitude entry equals 1. Used only for matching.
    AsTop=As(1:ops.shared,:,:);
    BsTop=Bs(1:ops.shared,:,:);
    CsTop=Cs(1:ops.shared,:,:);
    for pp=1:ops.p
        for f=1:ops.R
            [~,mloc]=max(abs(AsTop(:,f,pp)));
            AsTop(:,f,pp)=AsTop(:,f,pp)/AsTop(mloc,f,pp);
            [~,mloc]=max(abs(BsTop(:,f,pp)));
            BsTop(:,f,pp)=BsTop(:,f,pp)/BsTop(mloc,f,pp);
            [~,mloc]=max(abs(CsTop(:,f,pp)));
            CsTop(:,f,pp)=CsTop(:,f,pp)/CsTop(mloc,f,pp);
        end
    end

    for pp=2:ops.p
        % Match the columns of replica pp to those of replica 1 using the
        % normalised shared rows.
        [AsTop(:,:,pp),PermpA]=perm2match(AsTop(:,:,pp),AsTop(:,:,1));
        [BsTop(:,:,pp),PermpB]=perm2match(BsTop(:,:,pp),BsTop(:,:,1));
        [CsTop(:,:,pp),PermpC]=perm2match(CsTop(:,:,pp),CsTop(:,:,1));

        % Apply the same column permutation to the full factors.
        As(:,:,pp)=As(:,:,pp)*PermpA;
        Bs(:,:,pp)=Bs(:,:,pp)*PermpB;
        Cs(:,:,pp)=Cs(:,:,pp)*PermpC;

        % Equalise the matched column scales using the first row, which
        % lies inside the shared part.
        LambdapA=diag(As(1,:,1)./As(1,:,pp));
        LambdapB=diag(Bs(1,:,1)./Bs(1,:,pp));
        LambdapC=diag(Cs(1,:,1)./Cs(1,:,pp));

        % Rescale the full factors of replica pp.
        As(:,:,pp)=As(:,:,pp)*LambdapA;
        Bs(:,:,pp)=Bs(:,:,pp)*LambdapB;
        Cs(:,:,pp)=Cs(:,:,pp)*LambdapC;
    end

    clear AsTop BsTop CsTop LambdapA LambdapB LambdapC; % no need now

    % Stack all matrices: replica 1 contributes all Q rows, every further
    % replica contributes only its Q-ops.shared non-shared rows.
    Q=ops.Q;
    nNew=Q-ops.shared;
    nRows=Q+nNew*(ops.p-1);

    allAs=zeros(nRows,ops.R);   allUTA=zeros(nRows,I);
    allBs=zeros(nRows,ops.R);   allUTB=zeros(nRows,J);
    allCs=zeros(nRows,ops.R);   allUTC=zeros(nRows,K);

    allAs(1:Q,:)=As(:,:,1);     allUTA(1:Q,:)=U(:,:,1)';
    allBs(1:Q,:)=Bs(:,:,1);     allUTB(1:Q,:)=V(:,:,1)';
    allCs(1:Q,:)=Cs(:,:,1);     allUTC(1:Q,:)=W(:,:,1)';

    for pp=2:ops.p
        rows=Q+nNew*(pp-2)+1:Q+nNew*(pp-1); % destination rows of replica pp
        uniq=ops.shared+1:Q;                % non-shared rows/columns
        allAs(rows,:)=As(uniq,:,pp);    allUTA(rows,:)=U(:,uniq,pp)';
        allBs(rows,:)=Bs(uniq,:,pp);    allUTB(rows,:)=V(:,uniq,pp)';
        allCs(rows,:)=Cs(uniq,:,pp);    allUTC(rows,:)=W(:,uniq,pp)';
    end
    p=ops.p;
    % Helper variables are cleared so they do not count towards memUsed.
    clear ops pp Q nNew nRows rows uniq;

    % Solve for the full size decomposition using least squares.
    A=pinv(allUTA)*allAs;
    B=pinv(allUTB)*allBs;
    C=pinv(allUTC)*allCs;

    % Memory accounting: total bytes of everything currently in the
    % workspace, counting 'summaries' at 1/p of its size.
    mem_elements=whos;
    memory_array=[mem_elements.bytes];
    isSummaries=strcmp({mem_elements.name},'summaries');
    memory_array(isSummaries)=memory_array(isSummaries)/p;
    memUsed=sum(memory_array);
end