1
0
Fork 0

feature: add cached Check and fix result unmarshalling

This commit is contained in:
Seraphim Strub 2024-08-04 16:57:37 +00:00
parent d374239c6c
commit 8c8b86bfa0
3 changed files with 92 additions and 55 deletions

2
go.mod
View file

@ -1,3 +1,5 @@
module sst.rievo.dev/go-abuseipdb
go 1.22
require github.com/allegro/bigcache/v3 v3.1.0

2
go.sum Normal file
View file

@ -0,0 +1,2 @@
github.com/allegro/bigcache/v3 v3.1.0 h1:H2Vp8VOvxcrB91o86fUSVJFqeuz8kpyyB02eH3bSzwk=
github.com/allegro/bigcache/v3 v3.1.0/go.mod h1:aPyh7jEvrog9zAwx5N7+JUQX5dZTSGpxF1LAR4dr35I=

View file

@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/allegro/bigcache/v3"
"html"
"io"
"mime/multipart"
@ -35,8 +36,11 @@ type Client struct {
// ratelimit
RateLimit *rate.Limit
cache *bigcache.BigCache
}
// NewClient pass nil to not use a custom *http.Client
func NewClient(client *http.Client) *Client {
if client == nil {
client = &http.Client{}
@ -51,6 +55,11 @@ func (c *Client) AddApiKey(key string) {
c.APIKey = key
}
// AddCache with "ttl", suggestion 30min
func (c *Client) AddCache(eviction time.Duration) {
c.cache, _ = bigcache.New(context.Background(), bigcache.DefaultConfig(eviction))
}
func (c *Client) initialize() {
if c.client == nil {
c.client = &http.Client{}
@ -63,7 +72,7 @@ func (c *Client) initialize() {
}
}
func (c *Client) NewRequest(method, urlStr string, parameter map[string]string, body io.Reader) (*http.Request, error) {
func (c *Client) newRequest(method, urlStr string, parameter map[string]string, body io.Reader) (*http.Request, error) {
if !strings.HasSuffix(c.BaseURL.Path, "/") {
return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL)
}
@ -111,7 +120,7 @@ func handleError(r *http.Response) error {
return nil
}
func (c *Client) Do(ctx context.Context, req *http.Request, v any) error {
func (c *Client) do(ctx context.Context, req *http.Request, v any) error {
resp, err := c.client.Do(req.WithContext(ctx))
if err != nil {
return err
@ -169,8 +178,9 @@ type CheckResultReport struct {
ReporterCountryName string `json:"reporterCountryName"`
}
var checkEndpoint = "check"
func (c *Client) Check(ctx context.Context, ip net.IP, opts *CheckOptions) (*CheckResult, error) {
var endpoint = "check"
ipAddress := html.EscapeString(ip.String())
parameters := map[string]string{
"ipAddress": ipAddress,
@ -185,17 +195,41 @@ func (c *Client) Check(ctx context.Context, ip net.IP, opts *CheckOptions) (*Che
}
}
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
req, err := c.newRequest(http.MethodGet, checkEndpoint, parameters, nil)
if err != nil {
return nil, err
}
var result *CheckResult
err = c.Do(ctx, req, result)
result := CheckResult{}
err = c.do(ctx, req, &result)
if err != nil {
return nil, err
}
return result, nil
return &result, nil
}
// CheckCached uses the client cache and the ip as a key
func (c *Client) CheckCached(ctx context.Context, ip net.IP, opts *CheckOptions) (*CheckResult, error) {
key := fmt.Sprintf("%s:%s", checkEndpoint, ip.String())
if r, err := c.cache.Get(key); err == nil {
result := CheckResult{}
if jsonErr := json.NewDecoder(bytes.NewReader(r)).Decode(&result); jsonErr == nil {
fmt.Println("return cached")
return &result, nil
}
}
if result, err := c.Check(ctx, ip, opts); err == nil {
resultEncoded := bytes.Buffer{}
if jsonErr := json.NewEncoder(&resultEncoded).Encode(result); jsonErr != nil {
return nil, jsonErr
}
if re := resultEncoded.Bytes(); re != nil {
c.cache.Set(key, re)
}
return result, nil
} else {
return nil, err
}
}
type ReportsOptions struct {
@ -244,17 +278,17 @@ func (c *Client) Reports(ctx context.Context, ip net.IP, opts *ReportsOptions) (
}
}
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil)
if err != nil {
return nil, err
}
var result *ReportsResult
err = c.Do(ctx, req, result)
result := ReportsResult{}
err = c.do(ctx, req, &result)
if err != nil {
return nil, err
}
return result, nil
return &result, nil
}
type BlacklistOptions struct {
@ -303,17 +337,17 @@ func (c *Client) Blacklist(ctx context.Context, opts *BlacklistOptions) (*Blackl
parameters := handleBlacklistOptions(opts)
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil)
if err != nil {
return nil, err
}
var result *BlacklistResult
err = c.Do(ctx, req, result)
result := BlacklistResult{}
err = c.do(ctx, req, &result)
if err != nil {
return nil, err
}
return result, nil
return &result, nil
}
func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io.Reader, error) {
@ -321,7 +355,7 @@ func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io
parameters := handleBlacklistOptions(opts)
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil)
if err != nil {
return nil, err
}
@ -332,12 +366,12 @@ func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io
return nil, err
}
defer resp.Body.Close()
var result *bytes.Buffer
_, err = io.Copy(result, resp.Body)
result := bytes.Buffer{}
_, err = io.Copy(&result, resp.Body)
if err != nil {
return nil, err
}
return result, nil
return &result, nil
}
type ReportOptions struct {
@ -353,7 +387,7 @@ type ReportResult struct {
} `json:"data"`
}
func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*ReportsResult, error) {
func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*ReportResult, error) {
var endpoint = "report"
ipAddress := html.EscapeString(ip.String())
@ -364,8 +398,8 @@ func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*R
if opts != nil {
if opts.Categories != nil {
categories := make([]string, len(opts.Categories))
for _, category := range opts.Categories {
categories = append(categories, strconv.Itoa(category))
for i, category := range opts.Categories {
categories[i] = strconv.Itoa(category)
}
parameters["categories"] = strings.Join(categories, ",")
}
@ -377,16 +411,16 @@ func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*R
}
}
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
req, err := c.newRequest(http.MethodPost, endpoint, parameters, nil)
if err != nil {
return nil, err
}
var result *ReportsResult
err = c.Do(ctx, req, result)
result := ReportResult{}
err = c.do(ctx, req, &result)
if err != nil {
return nil, err
}
return result, nil
return &result, nil
}
type CheckBlockOptions struct {
@ -425,17 +459,17 @@ func (c *Client) CheckBlock(ctx context.Context, ipnNet net.IPNet, opts *CheckBl
}
}
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil)
if err != nil {
return nil, err
}
var result *CheckBlockResult
err = c.Do(ctx, req, result)
result := CheckBlockResult{}
err = c.do(ctx, req, &result)
if err != nil {
return nil, err
}
return result, nil
return &result, nil
}
type BulkReportData struct {
@ -488,7 +522,7 @@ func (b *BulkReportDatas) toSlice() [][]string {
return slice
}
func (b *BulkReportDatas) Csv() io.Reader {
func (b *BulkReportDatas) csv() io.Reader {
result := &bytes.Buffer{}
writer := csv.NewWriter(result)
@ -518,6 +552,20 @@ func (b *BulkReportDatas) validate() error {
}
return nil
}
func toMultipartCsv(data io.Reader) (io.Reader, string) {
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
formFile, err := writer.CreateFormFile("csv", "data.csv")
if err != nil {
return nil, ""
}
_, err = io.Copy(formFile, data)
if err != nil {
return nil, ""
}
writer.Close()
return &requestBody, writer.FormDataContentType()
}
type BulkReportResult struct {
Data struct {
@ -539,39 +587,24 @@ func (c *Client) BulkReport(ctx context.Context, data *BulkReportDatas) (*BulkRe
if err := data.validate(); err != nil {
return nil, err
}
csvData := data.Csv()
csvData := data.csv()
if csvData == nil {
return nil, errors.New("bulk report: no csv data")
}
requestBody, contentType := toMultipartCsv(csvData)
req, err := c.NewRequest(http.MethodGet, endpoint, nil, requestBody)
req, err := c.newRequest(http.MethodPost, endpoint, nil, requestBody)
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", contentType)
var result *BulkReportResult
err = c.Do(ctx, req, &result)
result := BulkReportResult{}
err = c.do(ctx, req, &result)
if err != nil {
return nil, err
}
return result, nil
}
func toMultipartCsv(data io.Reader) (io.Reader, string) {
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
formFile, err := writer.CreateFormFile("csv", "data.csv")
if err != nil {
return nil, ""
}
_, err = io.Copy(formFile, data)
if err != nil {
return nil, ""
}
writer.Close()
return &requestBody, writer.FormDataContentType()
return &result, nil
}
type ClearAddressData struct {
@ -587,14 +620,14 @@ func (c *Client) ClearAddress(ctx context.Context, ip net.IP) (*ClearAddressData
"ipAddress": ipAddress,
}
req, err := c.NewRequest(http.MethodPost, endpoint, parameters, nil)
req, err := c.newRequest(http.MethodDelete, endpoint, parameters, nil)
if err != nil {
return nil, err
}
var result *ClearAddressData
err = c.Do(ctx, req, &result)
result := ClearAddressData{}
err = c.do(ctx, req, &result)
if err != nil {
return nil, err
}
return result, nil
return &result, nil
}