diff --git a/go.mod b/go.mod index 9e0f3e5..32e88a7 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module sst.rievo.dev/go-abuseipdb go 1.22 + +require github.com/allegro/bigcache/v3 v3.1.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..25ea6c5 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/allegro/bigcache/v3 v3.1.0 h1:H2Vp8VOvxcrB91o86fUSVJFqeuz8kpyyB02eH3bSzwk= +github.com/allegro/bigcache/v3 v3.1.0/go.mod h1:aPyh7jEvrog9zAwx5N7+JUQX5dZTSGpxF1LAR4dr35I= diff --git a/pkg/abuseipdb/client.go b/pkg/abuseipdb/client.go index 0426ebf..308e18b 100644 --- a/pkg/abuseipdb/client.go +++ b/pkg/abuseipdb/client.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/allegro/bigcache/v3" "html" "io" "mime/multipart" @@ -35,8 +36,11 @@ type Client struct { // ratelimit RateLimit *rate.Limit + + cache *bigcache.BigCache } +// NewClient pass nil to not use a custom *http.Client func NewClient(client *http.Client) *Client { if client == nil { client = &http.Client{} @@ -51,6 +55,11 @@ func (c *Client) AddApiKey(key string) { c.APIKey = key } +// AddCache with "ttl", suggestion 30min +func (c *Client) AddCache(eviction time.Duration) { + c.cache, _ = bigcache.New(context.Background(), bigcache.DefaultConfig(eviction)) +} + func (c *Client) initialize() { if c.client == nil { c.client = &http.Client{} @@ -63,7 +72,7 @@ func (c *Client) initialize() { } } -func (c *Client) NewRequest(method, urlStr string, parameter map[string]string, body io.Reader) (*http.Request, error) { +func (c *Client) newRequest(method, urlStr string, parameter map[string]string, body io.Reader) (*http.Request, error) { if !strings.HasSuffix(c.BaseURL.Path, "/") { return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL) } @@ -111,7 +120,7 @@ func handleError(r *http.Response) error { return nil } -func (c *Client) Do(ctx context.Context, req *http.Request, v any) error { +func (c *Client) do(ctx context.Context, req *http.Request, v any) error { resp, err := c.client.Do(req.WithContext(ctx)) if err != nil { return err @@ -169,8 +178,9 @@ type CheckResultReport struct { ReporterCountryName string `json:"reporterCountryName"` } +var checkEndpoint = "check" + func (c *Client) Check(ctx context.Context, ip net.IP, opts *CheckOptions) (*CheckResult, error) { - var endpoint = "check" ipAddress := html.EscapeString(ip.String()) parameters := map[string]string{ "ipAddress": ipAddress, @@ -185,17 +195,41 @@ func (c *Client) Check(ctx context.Context, ip net.IP, opts *CheckOptions) (*Che } } - req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) + req, err := c.newRequest(http.MethodGet, checkEndpoint, parameters, nil) if err != nil { return nil, err } - var result *CheckResult - err = c.Do(ctx, req, result) + result := CheckResult{} + err = c.do(ctx, req, &result) if err != nil { return nil, err } - return result, nil + return &result, nil +} + +// CheckCached uses the client cache and the ip as a key +func (c *Client) CheckCached(ctx context.Context, ip net.IP, opts *CheckOptions) (*CheckResult, error) { + key := fmt.Sprintf("%s:%s", checkEndpoint, ip.String()) + if r, err := c.cache.Get(key); err == nil { + result := CheckResult{} + if jsonErr := json.NewDecoder(bytes.NewReader(r)).Decode(&result); jsonErr == nil { + fmt.Println("return cached") + return &result, nil + } + } + if result, err := c.Check(ctx, ip, opts); err == nil { + resultEncoded := bytes.Buffer{} + if jsonErr := json.NewEncoder(&resultEncoded).Encode(result); jsonErr != nil { + return nil, jsonErr + } + if re := resultEncoded.Bytes(); re != nil { + c.cache.Set(key, re) + } + return result, nil + } else { + return nil, err + } } type ReportsOptions struct { @@ -244,17 +278,17 @@ func (c *Client) Reports(ctx context.Context, ip net.IP, opts *ReportsOptions) ( } } - req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) + req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil) if err != nil { return nil, err } - var result *ReportsResult - err = c.Do(ctx, req, result) + result := ReportsResult{} + err = c.do(ctx, req, &result) if err != nil { return nil, err } - return result, nil + return &result, nil } type BlacklistOptions struct { @@ -303,17 +337,17 @@ func (c *Client) Blacklist(ctx context.Context, opts *BlacklistOptions) (*Blackl parameters := handleBlacklistOptions(opts) - req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) + req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil) if err != nil { return nil, err } - var result *BlacklistResult - err = c.Do(ctx, req, result) + result := BlacklistResult{} + err = c.do(ctx, req, &result) if err != nil { return nil, err } - return result, nil + return &result, nil } func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io.Reader, error) { @@ -321,7 +355,7 @@ func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io parameters := handleBlacklistOptions(opts) - req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) + req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil) if err != nil { return nil, err } @@ -332,12 +366,12 @@ func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io return nil, err } defer resp.Body.Close() - var result *bytes.Buffer - _, err = io.Copy(result, resp.Body) + result := bytes.Buffer{} + _, err = io.Copy(&result, resp.Body) if err != nil { return nil, err } - return result, nil + return &result, nil } type ReportOptions struct { @@ -353,7 +387,7 @@ type ReportResult struct { } `json:"data"` } -func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*ReportsResult, error) { +func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*ReportResult, error) { var endpoint = "report" ipAddress := html.EscapeString(ip.String()) @@ -364,8 +398,8 @@ func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*R if opts != nil { if opts.Categories != nil { categories := make([]string, len(opts.Categories)) - for _, category := range opts.Categories { - categories = append(categories, strconv.Itoa(category)) + for i, category := range opts.Categories { + categories[i] = strconv.Itoa(category) } parameters["categories"] = strings.Join(categories, ",") } @@ -377,16 +411,16 @@ func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*R } } - req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) + req, err := c.newRequest(http.MethodPost, endpoint, parameters, nil) if err != nil { return nil, err } - var result *ReportsResult - err = c.Do(ctx, req, result) + result := ReportResult{} + err = c.do(ctx, req, &result) if err != nil { return nil, err } - return result, nil + return &result, nil } type CheckBlockOptions struct { @@ -425,17 +459,17 @@ func (c *Client) CheckBlock(ctx context.Context, ipnNet net.IPNet, opts *CheckBl } } - req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) + req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil) if err != nil { return nil, err } - var result *CheckBlockResult - err = c.Do(ctx, req, result) + result := CheckBlockResult{} + err = c.do(ctx, req, &result) if err != nil { return nil, err } - return result, nil + return &result, nil } type BulkReportData struct { @@ -488,7 +522,7 @@ func (b *BulkReportDatas) toSlice() [][]string { return slice } -func (b *BulkReportDatas) Csv() io.Reader { +func (b *BulkReportDatas) csv() io.Reader { result := &bytes.Buffer{} writer := csv.NewWriter(result) @@ -518,6 +552,20 @@ func (b *BulkReportDatas) validate() error { } return nil } +func toMultipartCsv(data io.Reader) (io.Reader, string) { + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + formFile, err := writer.CreateFormFile("csv", "data.csv") + if err != nil { + return nil, "" + } + _, err = io.Copy(formFile, data) + if err != nil { + return nil, "" + } + writer.Close() + return &requestBody, writer.FormDataContentType() +} type BulkReportResult struct { Data struct { @@ -539,39 +587,24 @@ func (c *Client) BulkReport(ctx context.Context, data *BulkReportDatas) (*BulkRe if err := data.validate(); err != nil { return nil, err } - csvData := data.Csv() + csvData := data.csv() if csvData == nil { return nil, errors.New("bulk report: no csv data") } requestBody, contentType := toMultipartCsv(csvData) - req, err := c.NewRequest(http.MethodGet, endpoint, nil, requestBody) + req, err := c.newRequest(http.MethodPost, endpoint, nil, requestBody) if err != nil { return nil, err } req.Header.Add("Content-Type", contentType) - var result *BulkReportResult - err = c.Do(ctx, req, &result) + result := BulkReportResult{} + err = c.do(ctx, req, &result) if err != nil { return nil, err } - return result, nil -} - -func toMultipartCsv(data io.Reader) (io.Reader, string) { - var requestBody bytes.Buffer - writer := multipart.NewWriter(&requestBody) - formFile, err := writer.CreateFormFile("csv", "data.csv") - if err != nil { - return nil, "" - } - _, err = io.Copy(formFile, data) - if err != nil { - return nil, "" - } - writer.Close() - return &requestBody, writer.FormDataContentType() + return &result, nil } type ClearAddressData struct { @@ -587,14 +620,14 @@ func (c *Client) ClearAddress(ctx context.Context, ip net.IP) (*ClearAddressData "ipAddress": ipAddress, } - req, err := c.NewRequest(http.MethodPost, endpoint, parameters, nil) + req, err := c.newRequest(http.MethodDelete, endpoint, parameters, nil) if err != nil { return nil, err } - var result *ClearAddressData - err = c.Do(ctx, req, &result) + result := ClearAddressData{} + err = c.do(ctx, req, &result) if err != nil { return nil, err } - return result, nil + return &result, nil }