diff --git a/go.mod b/go.mod index 739002d..9e0f3e5 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module go-abuseipdb +module sst.rievo.dev/go-abuseipdb go 1.22 diff --git a/pkg/abuseipdb/client.go b/pkg/abuseipdb/client.go index ef44a7a..8140a05 100644 --- a/pkg/abuseipdb/client.go +++ b/pkg/abuseipdb/client.go @@ -3,14 +3,18 @@ package abuseipdb import ( "bytes" "context" + "encoding/csv" "encoding/json" + "errors" "fmt" - "go-abuseipdb/pkg/abuseipdb/rate" "html" "io" + "mime/multipart" "net" "net/http" "net/url" + "reflect" + "sst.rievo.dev/go-abuseipdb/pkg/abuseipdb/rate" "strconv" "strings" "time" @@ -59,7 +63,7 @@ func (c *Client) initialize() { } } -func (c *Client) NewRequest(method, urlStr string, parameter map[string]string) (*http.Request, error) { +func (c *Client) NewRequest(method, urlStr string, parameter map[string]string, body io.Reader) (*http.Request, error) { if !strings.HasSuffix(c.BaseURL.Path, "/") { return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL) } @@ -69,7 +73,7 @@ func (c *Client) NewRequest(method, urlStr string, parameter map[string]string) return nil, err } - req, err := http.NewRequest(method, u.String(), nil) + req, err := http.NewRequest(method, u.String(), body) if err != nil { return nil, err } @@ -181,7 +185,7 @@ func (c *Client) Check(ctx context.Context, ip net.IP, opts *CheckOptions) (*Che } } - req, err := c.NewRequest(http.MethodGet, endpoint, parameters) + req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) if err != nil { return nil, err } @@ -240,7 +244,7 @@ func (c *Client) Reports(ctx context.Context, ip net.IP, opts *ReportsOptions) ( } } - req, err := c.NewRequest(http.MethodGet, endpoint, parameters) + req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) if err != nil { return nil, err } @@ -299,7 +303,7 @@ func (c *Client) Blacklist(ctx context.Context, opts *BlacklistOptions) (*Blackl parameters := handleBlacklistOptions(opts) - req, err := c.NewRequest(http.MethodGet, endpoint, parameters) + req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) if err != nil { return nil, err } @@ -317,7 +321,7 @@ func (c *Client) BlacklistPlain(ctx context.Context, opts *BlacklistOptions) (io parameters := handleBlacklistOptions(opts) - req, err := c.NewRequest(http.MethodGet, endpoint, parameters) + req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) if err != nil { return nil, err } @@ -373,7 +377,7 @@ func (c *Client) Report(ctx context.Context, ip net.IP, opts *ReportOptions) (*R } } - req, err := c.NewRequest(http.MethodGet, endpoint, parameters) + req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) if err != nil { return nil, err } @@ -421,7 +425,7 @@ func (c *Client) CheckBlock(ctx context.Context, ipnNet net.IPNet, opts *CheckBl } } - req, err := c.NewRequest(http.MethodGet, endpoint, parameters) + req, err := c.NewRequest(http.MethodGet, endpoint, parameters, nil) if err != nil { return nil, err } @@ -436,11 +440,85 @@ func (c *Client) CheckBlock(ctx context.Context, ipnNet net.IPNet, opts *CheckBl type BulkReportData struct { IpAddress string `csv:"IP"` - Categories []string `csv:"Categories"` + Categories []int `csv:"Categories"` ReportDate time.Time `csv:"ReportDate"` Comment string `csv:"Comment"` } +type BulkReportDatas []BulkReportData + +func (b *BulkReportDatas) getHeaders() []string { + if b == nil || len(*b) == 0 { + return nil + } + val := reflect.ValueOf((*b)[0]) + var headers []string + for i := 0; i < val.Type().NumField(); i++ { + t := val.Type().Field(i) + fieldName := t.Name + + jsonTag := t.Tag.Get("csv") + parts := strings.Split(jsonTag, ",") + name := parts[0] + if name == "" { + name = fieldName + } + headers = append(headers, name) + } + return headers +} + +func (b *BulkReportDatas) toSlice() [][]string { + var slice [][]string + if b == nil || len(*b) == 0 { + return nil + } + for _, data := range *b { + categories := make([]string, len(data.Categories)) + for i, category := range data.Categories { + categories[i] = strconv.Itoa(category) + } + slice = append(slice, []string{ + data.IpAddress, + strings.Join(categories, ","), + data.ReportDate.Format(time.RFC3339), + data.Comment, + }) + } + return slice +} + +func (b *BulkReportDatas) Csv() io.Reader { + result := &bytes.Buffer{} + writer := csv.NewWriter(result) + + err := writer.WriteAll(append([][]string{b.getHeaders()}, b.toSlice()...)) + if err != nil { + return nil + } + + return result +} + +func (b *BulkReportDatas) validate() error { + if b == nil || len(*b) == 0 { + return nil + } + + for i, data := range *b { + if len(data.Categories) == 0 { + return fmt.Errorf("no categories found in BulkReportDatas entry %d", i) + } + if data.ReportDate.IsZero() { + return fmt.Errorf("no report date found in BulkReportDatas entry %d", i) + } + if len(data.Comment) > 1_024 { + return fmt.Errorf("comment is too long (%d bytes) for entry %d", len(data.Comment), i) + } + } + return nil +} + type BulkReportResult struct { Data struct { SavedReports int `json:"savedReports"` @@ -452,10 +530,43 @@ type BulkReportResult struct { } `json:"data"` } -func (c *Client) BulkReport(ctx context.Context, data *BulkReportData) (*BulkReportResult, error) { - //var endpoint = "bulk-report" +func (c *Client) BulkReport(ctx context.Context, data *BulkReportDatas) (*BulkReportResult, error) { + var endpoint = "bulk-report" - return nil, fmt.Errorf("not implemented") + if data == nil || len(*data) == 0 { + return nil, errors.New("bulk report: no data") + } + if err := data.validate(); err != nil { + return nil, err + } + csvData := data.Csv() + if csvData == nil { + return nil, errors.New("bulk report: no csv data") + } + + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + formFile, err := writer.CreateFormFile("csv", "data.csv") + if err != nil { + return nil, err + } + _, err = io.Copy(formFile, csvData) + if err != nil { + return nil, err + } + writer.Close() + req, err := c.NewRequest(http.MethodGet, endpoint, nil, &requestBody) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", writer.FormDataContentType()) + + var result *BulkReportResult + err = c.Do(ctx, req, &result) + if err != nil { + return nil, err + } + return result, nil } type ClearAddressData struct { @@ -471,7 +582,7 @@ func (c *Client) ClearAddress(ctx context.Context, ip net.IP) (*ClearAddressData "ipAddress": ipAddress, } - req, err := c.NewRequest(http.MethodPost, endpoint, parameters) + req, err := c.NewRequest(http.MethodPost, endpoint, parameters, nil) if err != nil { return nil, err }