// Copyright (c) Facebook, Inc. and its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package api

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/facebookincubator/nvdtools/providers/fireeye/schema"
	"github.com/facebookincubator/nvdtools/providers/lib/client"
	"github.com/facebookincubator/nvdtools/providers/lib/runner"
	"github.com/facebookincubator/nvdtools/stats"
	"golang.org/x/sync/errgroup"
)

// FetchAllVulnerabilities fetches every vulnerability starting from since and
// streams them on the returned channel.
//
// The time range beginning at since is split into batches via
// batchBy(ninetyDays), and each batch is fetched in its own goroutine. The
// returned channel is closed once every batch goroutine has exited.
//
// A batch fetch error is passed through client.StopOrContinue. If that yields a
// non-nil error, the errgroup context is cancelled, which stops the remaining
// batches. The first such error is logged rather than returned, because the
// channel has already been handed to the caller by then. The only error
// returned directly is a parameter validation failure.
//
// NOTE(review): since is presumably a Unix timestamp — confirm against
// newParametersSince.
func (c *Client) FetchAllVulnerabilities(ctx context.Context, since int64) (<-chan runner.Convertible, error) {
	parameters := newParametersSince(since)
	if err := parameters.validate(); err != nil {
		return nil, err
	}

	output := make(chan runner.Convertible)

	// The derived ctx is cancelled when any batch returns a non-nil error or
	// when the caller's ctx is cancelled.
	eg, ctx := errgroup.WithContext(ctx)
	for _, params := range parameters.batchBy(ninetyDays) {
		params := params // per-iteration copy for the closure (required before Go 1.22)
		eg.Go(func() error {
			log.Printf("Fetching: %s", params)
			vs, err := c.fetchVulnerabilities(ctx, params)
			if err != nil {
				return client.StopOrContinue(fmt.Errorf("error while fetching %s: %v", params, err))
			}
			numVulns := len(vs)
			log.Printf("Adding %d vulns", numVulns)
			stats.IncrementCounterBy("vulnerabilities", int64(numVulns))
			for _, v := range vs {
				// A bare send would block forever if nobody is reading anymore.
				// That would keep eg.Wait from returning, leave output unclosed
				// and leak this goroutine. Give up as soon as the group (or the
				// caller) is cancelled.
				select {
				case output <- v:
				case <-ctx.Done():
					return ctx.Err()
				}
			}
			return nil
		})
	}

	// Close output only after every producer goroutine has exited, so that a
	// consumer ranging over the channel terminates.
	go func() {
		if err := eg.Wait(); err != nil {
			log.Println(err)
		}
		close(output)
	}()

	return output, nil
}

// fetchVulnerabilities requests the /view/vulnerability endpoint for one time
// window (encoded by parameters.query()) and decodes the response as a JSON
// array of vulnerabilities. On any request or decode error it returns a nil
// slice together with that error, unwrapped.
//
// NOTE(review): the value returned by c.Request is only read here, never
// closed. If Request can return a body that must be closed, close it here —
// confirm against Client.Request.
func (c *Client) fetchVulnerabilities(ctx context.Context, parameters timeRangeParameters) ([]*schema.Vulnerability, error) {
	endpoint := fmt.Sprintf("/view/vulnerability?%s", parameters.query())
	body, err := c.Request(ctx, endpoint)
	if err != nil {
		return nil, err
	}

	decoder := json.NewDecoder(body)
	var result []*schema.Vulnerability
	if err = decoder.Decode(&result); err != nil {
		return nil, err
	}
	return result, nil
}
