1
0
Fork 0

feature: implement bulk-report endpoints

This commit is contained in:
Seraphim Strub 2024-08-01 13:03:30 +00:00
parent 1009e032e4
commit e6b9839e42
2 changed files with 126 additions and 15 deletions

2
go.mod
View file

@ -1,3 +1,3 @@
module go-abuseipdb module sst.rievo.dev/go-abuseipdb
go 1.22 go 1.22

View file

@ -3,14 +3,18 @@ package abuseipdb
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/csv"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"go-abuseipdb/pkg/abuseipdb/rate"
"html" "html"
"io" "io"
"mime/multipart"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"sst.rievo.dev/go-abuseipdb/pkg/abuseipdb/rate"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -59,7 +63,7 @@ func (c *Client) initialize() {
} }
} }
func (c *Client) NewRequest(method, urlStr string, parameter map[string]string) (*http.Request, error) { func (c *Client) NewRequest(method, urlStr string, parameter map[string]string, body io.Reader) (*http.Request, error) {
if !strings.HasSuffix(c.BaseURL.Path, "/") { if !strings.HasSuffix(c.BaseURL.Path, "/") {
return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL) return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL)
} }
@ -69,7 +73,7 @@ func (c *Client) NewRequest(method, urlStr string, parameter map[string]string)
return nil, err return nil, err
} }
req, err := http.NewRequest(method, u.String(), nil) req, err := http.NewRequest(method, u.String(), body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -181,7 +185,7 @@ func (c *Client) Check(ctx context.Context, ip net.IP, opts *CheckOptions) (*Che
} }
} }
req, err := c.NewRequest(http.MethodGet, endpoint, parameters) req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -240,7 +244,7 @@ func (c *Client) Reports(ctx context.Context, ip net.IP, opts *ReportsOptions) (
} }
} }
req, err := c.NewRequest(http.MethodGet, endpoint, parameters) req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -299,7 +303,7 @@ func (c *Client) Blacklist(ctx context.Context, opts *BlacklistOptions) (*Blackl
parameters := handleBlacklistOptions(opts) parameters := handleBlacklistOptions(opts)
req, err := c.NewRequest(http.MethodGet, endpoint, parameters) req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -317,7 +321,7 @@ func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io
parameters := handleBlacklistOptions(opts) parameters := handleBlacklistOptions(opts)
req, err := c.NewRequest(http.MethodGet, endpoint, parameters) req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -373,7 +377,7 @@ func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*R
} }
} }
req, err := c.NewRequest(http.MethodGet, endpoint, parameters) req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -421,7 +425,7 @@ func (c *Client) CheckBlock(ctx context.Context, ipnNet net.IPNet, opts *CheckBl
} }
} }
req, err := c.NewRequest(http.MethodGet, endpoint, parameters) req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -436,11 +440,85 @@ func (c *Client) CheckBlock(ctx context.Context, ipnNet net.IPNet, opts *CheckBl
type BulkReportData struct { type BulkReportData struct {
IpAddress string `csv:"IP"` IpAddress string `csv:"IP"`
Categories []string `csv:"Categories"` Categories []int `csv:"Categories"`
ReportDate time.Time `csv:"ReportDate"` ReportDate time.Time `csv:"ReportDate"`
Comment string `csv:"Comment"` Comment string `csv:"Comment"`
} }
type BulkReportDatas []BulkReportData
func (b *BulkReportDatas) getHeaders() []string {
if b == nil || len(*b) == 0 {
return nil
}
val := reflect.ValueOf((*b)[0])
var headers []string
for i := 0; i < val.Type().NumField(); i++ {
t := val.Type().Field(i)
fieldName := t.Name
jsonTag := t.Tag.Get("csv")
parts := strings.Split(jsonTag, ",")
name := parts[0]
if name == "" {
name = fieldName
}
headers = append(headers, name)
}
return headers
}
func (b *BulkReportDatas) toSlice() [][]string {
var slice [][]string
if b == nil || len(*b) == 0 {
return nil
}
for _, data := range *b {
categories := make([]string, len(data.Categories))
for i, category := range data.Categories {
categories[i] = strconv.Itoa(category)
}
slice = append(slice, []string{
data.IpAddress,
strings.Join(categories, ","),
data.ReportDate.Format(time.RFC3339),
data.Comment,
})
}
return slice
}
func (b *BulkReportDatas) Csv() io.Reader {
result := &bytes.Buffer{}
writer := csv.NewWriter(result)
err := writer.WriteAll(append([][]string{b.getHeaders()}, b.toSlice()...))
if err != nil {
return nil
}
return result
}
func (b *BulkReportDatas) validate() error {
if b == nil || len(*b) == 0 {
return nil
}
for i, data := range *b {
if len(data.Categories) == 0 {
return fmt.Errorf("no categories found in BulkReportDatas entry %d", i)
}
if data.ReportDate.IsZero() {
return fmt.Errorf("no report date found in BulkReportDatas entry %d", i)
}
if len(data.Comment) > 1_024 {
return fmt.Errorf("comment is too long (%d bytes) for entry %d", len(data.Comment), i)
}
}
return nil
}
type BulkReportResult struct { type BulkReportResult struct {
Data struct { Data struct {
SavedReports int `json:"savedReports"` SavedReports int `json:"savedReports"`
@ -452,10 +530,43 @@ type BulkReportResult struct {
} `json:"data"` } `json:"data"`
} }
func (c *Client) BulkReport(ctx context.Context, data *BulkReportData) (*BulkReportResult, error) { func (c *Client) BulkReport(ctx context.Context, data *BulkReportDatas) (*BulkReportResult, error) {
//var endpoint = "bulk-report" var endpoint = "bulk-report"
return nil, fmt.Errorf("not implemented") if data == nil || len(*data) == 0 {
return nil, errors.New("bulk report: no data")
}
if err := data.validate(); err != nil {
return nil, err
}
csvData := data.Csv()
if csvData == nil {
return nil, errors.New("bulk report: no csv data")
}
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
formFile, err := writer.CreateFormFile("csv", "data.csv")
if err != nil {
return nil, err
}
_, err = io.Copy(formFile, csvData)
if err != nil {
return nil, err
}
writer.Close()
req, err := c.NewRequest(http.MethodGet, endpoint, nil, &requestBody)
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", writer.FormDataContentType())
var result *BulkReportResult
err = c.Do(ctx, req, &result)
if err != nil {
return nil, err
}
return result, nil
} }
type ClearAddressData struct { type ClearAddressData struct {
@ -471,7 +582,7 @@ func (c *Client) ClearAddress(ctx context.Context, ip net.IP) (*ClearAddressData
"ipAddress": ipAddress, "ipAddress": ipAddress,
} }
req, err := c.NewRequest(http.MethodPost, endpoint, parameters) req, err := c.NewRequest(http.MethodPost, endpoint, parameters, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }