%% Inspired by Xue Mei and Haibin Ling, "Robust Visual Tracking using l1 Minimization"

% ---------------------------------------------------------------------
% Parameters
% ---------------------------------------------------------------------
% Input folder: every *.mov file in it is processed by the loop below,
% which also reads a same-named *.png mask next to each video.
folder = '/home/sulcanto/animal_behavior/data/ants_tracking_scaffolding/Antonin_Ants_Scaffolding/Videos/80/';
files = dir([folder '*.mov']);

% Connected-component area limits used to accept a detected blob.
min_area = 200;
max_area = 25000;

% Default file index. The per-file loop overwrites it, but it is kept so
% that individual %% cells can be run interactively for one video.
i_file = 1;
% Spatial down-sampling factor applied to every frame (both dimensions).
subsample = 4;
% Convert the area limits to the down-sampled resolution.
% NOTE(review): when both image dimensions are divided by `subsample`,
% pixel areas shrink by subsample^2, not by subsample. Dividing by
% `subsample` is kept here because the thresholds may have been tuned
% against it -- confirm which units min_area/max_area are meant in.
min_area = min_area / subsample;
max_area = max_area / subsample;


% Number of leading frames used for the mean background and for
% collecting dictionary-training patches.
t_frames = 300;
lambda_dict = 0.15;   % param.lambda passed to mexTrainDL
lambda_code = 0.55;   % param.lambda passed to mexLasso
K = 128;              % number of dictionary atoms
%{
lambda_code = 0.4;
K = 16;
%}
% im2col/col2im block mode used for per-frame sparse coding.
window_type = 'sliding';

for i_file = 1:length(files)
    % ------------------------------------------------------------------
    % Per-video pipeline:
    %   1. read up to t_frames frames (down-sampled) and take their mean
    %      as the background;
    %   2. detect foreground blobs inside the pure-red (255,0,0) region of
    %      <video>.png and collect their gray-level patches;
    %   3. learn a dictionary with SPAMS mexTrainDL on the patches, all
    %      resized to the median patch size (Ho x Wo);
    %   4. sparse-code every Ho x Wo window of every frame with mexLasso,
    %      show/write "masked frame | summed code map" side by side to an
    %      AVI, and save one .mat file per frame.
    % ------------------------------------------------------------------
    vid_name = files(i_file).name;
    in = VideoReader([folder vid_name]);

    out_dir = [folder '/res/sparse_coding/' vid_name(1:end-4)];
    if ~exist(out_dir, 'dir')
        mkdir(out_dir)
    end
    % Region of interest: pure red pixels of the companion PNG.
    mask = imread([folder vid_name(1:end-3) 'png']);
    mask = (mask(:,:,1) == 255) & mask(:,:,2) == 0 & mask(:,:,3) == 0;


    % round(): imresize/zeros require integer sizes, and the frame size is
    % not guaranteed to be a multiple of subsample.
    H = round(in.Height / subsample);
    W = round(in.Width / subsample);
    I = zeros(H,W,3,t_frames,'single');


    % extracting sequence of (up to) t_frames frames; stops early if the
    % video is shorter instead of failing inside readFrame.
    mask = imresize(mask,[H,W]);
    n_read = 0;
    while n_read < t_frames && hasFrame(in)
        n_read = n_read + 1;
        I(:,:,:,n_read) = imresize(im2single(readFrame(in)),[H,W]);

        fprintf('taking frame %i of %i\n',n_read,t_frames)
    end
    I = I(:,:,:,1:n_read);
    % mean frame (background extraction)
    m = mean(I,4);

    % Foreground detection per frame: absolute difference to the mean
    % frame summed over colour channels, restricted to the mask and
    % thresholded at 5x the mean of the non-zero masked differences.
    n_frames = size(I,4);
    models = cell(n_frames,1);
    for f=1:n_frames

        img = sum(abs(I(:,:,:,f)- m),3);
        img_m = mask .* img;
        img_m = img_m > 5*mean( img_m ( img_m (:) > 0) );

        img_m = imopen(img_m, strel('rectangle', [3,3]));
        img_m = imclose(img_m, strel('rectangle', [10, 10]));
        img_m = imfill(img_m, 'holes');

        bboxes = regionprops(img_m,'BoundingBox','Area');
        bboxes = bboxes([bboxes.Area] >= min_area & [bboxes.Area] <= max_area);
        bboxes = {bboxes.BoundingBox};

%         imagesc(img); axis off; axis image;
        models{f} = cell(length(bboxes),1);
        for i=1:length(bboxes)
            % BoundingBox is [x y w h] with x,y on pixel edges (k+0.5), so
            % the box covers rows ceil(y) .. ceil(y)+h-1 and columns
            % ceil(x) .. ceil(x)+w-1. Using ceil(y) .. ceil(y)+h would take
            % one extra row/column and can index past the image border.
            bb = bboxes{i};
            r0 = ceil(bb(2));
            c0 = ceil(bb(1));
            patch = I(r0:r0+bb(4)-1, c0:c0+bb(3)-1, :, f);
            models{f}{i} = rgb2gray(patch);
        end
        fprintf('object detection (%i,%i), found %i \n',f,n_frames,length(models{f}))
    end

    %%
    n_models = arrayfun(@(i) length(models{i}),1:numel(models));
    % All detected patches, flattened in frame order.
    all_models = vertcat(models{:});
    if isempty(all_models)
        % Without any patch there is no template size or training data.
        warning('no objects detected in %s, skipping', vid_name);
        continue
    end
    Hs = cellfun(@(p) size(p,1), all_models);
    Ws = cellfun(@(p) size(p,2), all_models);

    % Template size = median patch size. round(): the median of an even
    % number of sizes can be fractional, which imresize/im2col/zeros reject.
    Wo = round(median(Ws));
    Ho = round(median(Hs));
    lin = @(v) v(:);

    % One column per patch, resized to Ho x Wo (preallocated rather than
    % grown with cat inside the loop).
    X = zeros(Ho*Wo, numel(all_models), 'single');
    for j=1:numel(all_models)
        X(:,j) = lin(imresize(all_models{j},[Ho,Wo]));
    end

    % Zero-mean, unit-L2-norm columns. max(..., eps) keeps constant
    % columns (zero norm after centring) at 0 instead of NaN.
    X = bsxfun(@minus, X, mean(X,1));
    X = bsxfun(@rdivide, X, max(sqrt(sum(X.^2,1)), eps(class(X))));
    %
    param.K=K;  % number of dictionary atoms
    param.lambda=lambda_dict;
    param.numThreads=-1; % number of threads
    param.batchsize=400;
    param.verbose=false;
    % NOTE(review): SPAMS documents `pos` as a mexLasso option; for
    % mexTrainDL the positivity options are named `posAlpha`/`posD`.
    % Confirm whether dictionary learning is meant to be non-negative too.
    param.pos = true;
    param.iter=1000;  % number of dictionary-learning iterations
    fprintf('dictionary creation %i elems, lambda %f\n',K,lambda_dict)
    t_dict = tic;
    D = mexTrainDL(X,param);
    fprintf('dictionary learned in %f s\n', toc(t_dict))
    %{
    D = pca(X');
    D = D(:,1:K);
    %}
    %%

    % Re-open the video to restart from its first frame.
    in = VideoReader([folder vid_name]);
    out = VideoWriter([folder '/res/sparse_coding/',vid_name(1:end-3) 'avi']);
    open(out)
    % Estimated frame count, used only in progress messages. Kept as a
    % double: uint16 would saturate at 65535 frames.
    NF = round(in.Duration * in.FrameRate);

    % Coding parameters are loop-invariant.
    param.lambda = lambda_code;
    param.K = K;
    frame = 1;
    while hasFrame(in)
        t_frame = tic;
        img = im2single(imresize(rgb2gray(readFrame(in)),[H,W]));
        % One column per Ho x Wo window of the frame.
        X = im2col(img,[Ho, Wo],window_type);
        X = bsxfun(@minus, X, mean(X,1));
        X = bsxfun(@rdivide, X, max(sqrt(sum(X.^2,1)), eps(class(X))));

        alpha=mexLasso(X,D,param);
        alpha_f = full(alpha);

        % Sum of the K code coefficients of each window, laid out on the
        % window grid. Summing the rows first and mapping once equals
        % mapping every atom's row with col2im and summing over atoms, but
        % avoids building an (H-Ho+1) x (W-Wo+1) x K array per frame.
        code_map = col2im(sum(alpha_f,1),[Ho,Wo],[H,W],window_type);

        img = mat2gray(...
            cat(2,mask .* img,mask.* imresize(mat2gray(code_map),[H,W])));

        imagesc(img); pause(0.01); axis image; colormap gray;

        writeVideo(out,img)

        save([out_dir '/' num2str(frame) '.mat'],...
           'alpha','H','W','Ho','Wo','img','D')


        fprintf('\n %s took %f (%i/%i)',vid_name,toc(t_frame),frame,NF)
        frame = frame + 1;
    end
    % Finalise the AVI file; VideoWriter output is only complete after
    % close().
    close(out)


end