feature: add cached Check and fix result unmarshalling
This commit is contained in:
parent
d374239c6c
commit
8c8b86bfa0
3 changed files with 92 additions and 55 deletions
2
go.mod
2
go.mod
|
@ -1,3 +1,5 @@
|
|||
module sst.rievo.dev/go-abuseipdb
|
||||
|
||||
go 1.22
|
||||
|
||||
require github.com/allegro/bigcache/v3 v3.1.0
|
||||
|
|
2
go.sum
Normal file
2
go.sum
Normal file
|
@ -0,0 +1,2 @@
|
|||
github.com/allegro/bigcache/v3 v3.1.0 h1:H2Vp8VOvxcrB91o86fUSVJFqeuz8kpyyB02eH3bSzwk=
|
||||
github.com/allegro/bigcache/v3 v3.1.0/go.mod h1:aPyh7jEvrog9zAwx5N7+JUQX5dZTSGpxF1LAR4dr35I=
|
|
@ -7,6 +7,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/allegro/bigcache/v3"
|
||||
"html"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
|
@ -35,8 +36,11 @@ type Client struct {
|
|||
|
||||
// ratelimit
|
||||
RateLimit *rate.Limit
|
||||
|
||||
cache *bigcache.BigCache
|
||||
}
|
||||
|
||||
// NewClient pass nil to not use a custom *http.Client
|
||||
func NewClient(client *http.Client) *Client {
|
||||
if client == nil {
|
||||
client = &http.Client{}
|
||||
|
@ -51,6 +55,11 @@ func (c *Client) AddApiKey(key string) {
|
|||
c.APIKey = key
|
||||
}
|
||||
|
||||
// AddCache with "ttl", suggestion 30min
|
||||
func (c *Client) AddCache(eviction time.Duration) {
|
||||
c.cache, _ = bigcache.New(context.Background(), bigcache.DefaultConfig(eviction))
|
||||
}
|
||||
|
||||
func (c *Client) initialize() {
|
||||
if c.client == nil {
|
||||
c.client = &http.Client{}
|
||||
|
@ -63,7 +72,7 @@ func (c *Client) initialize() {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Client) NewRequest(method, urlStr string, parameter map[string]string, body io.Reader) (*http.Request, error) {
|
||||
func (c *Client) newRequest(method, urlStr string, parameter map[string]string, body io.Reader) (*http.Request, error) {
|
||||
if !strings.HasSuffix(c.BaseURL.Path, "/") {
|
||||
return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL)
|
||||
}
|
||||
|
@ -111,7 +120,7 @@ func handleError(r *http.Response) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Do(ctx context.Context, req *http.Request, v any) error {
|
||||
func (c *Client) do(ctx context.Context, req *http.Request, v any) error {
|
||||
resp, err := c.client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -169,8 +178,9 @@ type CheckResultReport struct {
|
|||
ReporterCountryName string `json:"reporterCountryName"`
|
||||
}
|
||||
|
||||
var checkEndpoint = "check"
|
||||
|
||||
func (c *Client) Check(ctx context.Context, ip net.IP, opts *CheckOptions) (*CheckResult, error) {
|
||||
var endpoint = "check"
|
||||
ipAddress := html.EscapeString(ip.String())
|
||||
parameters := map[string]string{
|
||||
"ipAddress": ipAddress,
|
||||
|
@ -185,17 +195,41 @@ func (c *Client) Check(ctx context.Context, ip net.IP, opts *CheckOptions) (*Che
|
|||
}
|
||||
}
|
||||
|
||||
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
|
||||
req, err := c.newRequest(http.MethodGet, checkEndpoint, parameters, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result *CheckResult
|
||||
err = c.Do(ctx, req, result)
|
||||
result := CheckResult{}
|
||||
err = c.do(ctx, req, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// CheckCached uses the client cache and the ip as a key
|
||||
func (c *Client) CheckCached(ctx context.Context, ip net.IP, opts *CheckOptions) (*CheckResult, error) {
|
||||
key := fmt.Sprintf("%s:%s", checkEndpoint, ip.String())
|
||||
if r, err := c.cache.Get(key); err == nil {
|
||||
result := CheckResult{}
|
||||
if jsonErr := json.NewDecoder(bytes.NewReader(r)).Decode(&result); jsonErr == nil {
|
||||
fmt.Println("return cached")
|
||||
return &result, nil
|
||||
}
|
||||
}
|
||||
if result, err := c.Check(ctx, ip, opts); err == nil {
|
||||
resultEncoded := bytes.Buffer{}
|
||||
if jsonErr := json.NewEncoder(&resultEncoded).Encode(result); jsonErr != nil {
|
||||
return nil, jsonErr
|
||||
}
|
||||
if re := resultEncoded.Bytes(); re != nil {
|
||||
c.cache.Set(key, re)
|
||||
}
|
||||
return result, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
type ReportsOptions struct {
|
||||
|
@ -244,17 +278,17 @@ func (c *Client) Reports(ctx context.Context, ip net.IP, opts *ReportsOptions) (
|
|||
}
|
||||
}
|
||||
|
||||
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
|
||||
req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result *ReportsResult
|
||||
err = c.Do(ctx, req, result)
|
||||
result := ReportsResult{}
|
||||
err = c.do(ctx, req, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
type BlacklistOptions struct {
|
||||
|
@ -303,17 +337,17 @@ func (c *Client) Blacklist(ctx context.Context, opts *BlacklistOptions) (*Blackl
|
|||
|
||||
parameters := handleBlacklistOptions(opts)
|
||||
|
||||
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
|
||||
req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result *BlacklistResult
|
||||
err = c.Do(ctx, req, result)
|
||||
result := BlacklistResult{}
|
||||
err = c.do(ctx, req, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io.Reader, error) {
|
||||
|
@ -321,7 +355,7 @@ func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io
|
|||
|
||||
parameters := handleBlacklistOptions(opts)
|
||||
|
||||
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
|
||||
req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -332,12 +366,12 @@ func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io
|
|||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var result *bytes.Buffer
|
||||
_, err = io.Copy(result, resp.Body)
|
||||
result := bytes.Buffer{}
|
||||
_, err = io.Copy(&result, resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
type ReportOptions struct {
|
||||
|
@ -353,7 +387,7 @@ type ReportResult struct {
|
|||
} `json:"data"`
|
||||
}
|
||||
|
||||
func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*ReportsResult, error) {
|
||||
func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*ReportResult, error) {
|
||||
var endpoint = "report"
|
||||
|
||||
ipAddress := html.EscapeString(ip.String())
|
||||
|
@ -364,8 +398,8 @@ func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*R
|
|||
if opts != nil {
|
||||
if opts.Categories != nil {
|
||||
categories := make([]string, len(opts.Categories))
|
||||
for _, category := range opts.Categories {
|
||||
categories = append(categories, strconv.Itoa(category))
|
||||
for i, category := range opts.Categories {
|
||||
categories[i] = strconv.Itoa(category)
|
||||
}
|
||||
parameters["categories"] = strings.Join(categories, ",")
|
||||
}
|
||||
|
@ -377,16 +411,16 @@ func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*R
|
|||
}
|
||||
}
|
||||
|
||||
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
|
||||
req, err := c.newRequest(http.MethodPost, endpoint, parameters, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result *ReportsResult
|
||||
err = c.Do(ctx, req, result)
|
||||
result := ReportResult{}
|
||||
err = c.do(ctx, req, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
type CheckBlockOptions struct {
|
||||
|
@ -425,17 +459,17 @@ func (c *Client) CheckBlock(ctx context.Context, ipnNet net.IPNet, opts *CheckBl
|
|||
}
|
||||
}
|
||||
|
||||
req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
|
||||
req, err := c.newRequest(http.MethodGet, endpoint, parameters, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result *CheckBlockResult
|
||||
err = c.Do(ctx, req, result)
|
||||
result := CheckBlockResult{}
|
||||
err = c.do(ctx, req, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
type BulkReportData struct {
|
||||
|
@ -488,7 +522,7 @@ func (b *BulkReportDatas) toSlice() [][]string {
|
|||
return slice
|
||||
}
|
||||
|
||||
func (b *BulkReportDatas) Csv() io.Reader {
|
||||
func (b *BulkReportDatas) csv() io.Reader {
|
||||
result := &bytes.Buffer{}
|
||||
writer := csv.NewWriter(result)
|
||||
|
||||
|
@ -518,6 +552,20 @@ func (b *BulkReportDatas) validate() error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
func toMultipartCsv(data io.Reader) (io.Reader, string) {
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
formFile, err := writer.CreateFormFile("csv", "data.csv")
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
}
|
||||
_, err = io.Copy(formFile, data)
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
}
|
||||
writer.Close()
|
||||
return &requestBody, writer.FormDataContentType()
|
||||
}
|
||||
|
||||
type BulkReportResult struct {
|
||||
Data struct {
|
||||
|
@ -539,39 +587,24 @@ func (c *Client) BulkReport(ctx context.Context, data *BulkReportDatas) (*BulkRe
|
|||
if err := data.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
csvData := data.Csv()
|
||||
csvData := data.csv()
|
||||
if csvData == nil {
|
||||
return nil, errors.New("bulk report: no csv data")
|
||||
}
|
||||
|
||||
requestBody, contentType := toMultipartCsv(csvData)
|
||||
req, err := c.NewRequest(http.MethodGet, endpoint, nil, requestBody)
|
||||
req, err := c.newRequest(http.MethodPost, endpoint, nil, requestBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Content-Type", contentType)
|
||||
|
||||
var result *BulkReportResult
|
||||
err = c.Do(ctx, req, &result)
|
||||
result := BulkReportResult{}
|
||||
err = c.do(ctx, req, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func toMultipartCsv(data io.Reader) (io.Reader, string) {
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
formFile, err := writer.CreateFormFile("csv", "data.csv")
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
}
|
||||
_, err = io.Copy(formFile, data)
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
}
|
||||
writer.Close()
|
||||
return &requestBody, writer.FormDataContentType()
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
type ClearAddressData struct {
|
||||
|
@ -587,14 +620,14 @@ func (c *Client) ClearAddress(ctx context.Context, ip net.IP) (*ClearAddressData
|
|||
"ipAddress": ipAddress,
|
||||
}
|
||||
|
||||
req, err := c.NewRequest(http.MethodPost, endpoint, parameters, nil)
|
||||
req, err := c.newRequest(http.MethodDelete, endpoint, parameters, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result *ClearAddressData
|
||||
err = c.Do(ctx, req, &result)
|
||||
result := ClearAddressData{}
|
||||
err = c.do(ctx, req, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue