├── .chglog
├── CHANGELOG.tpl.md
└── config.yml
├── .github
├── dependabot.yml
└── workflows
│ ├── build-test.yml
│ ├── changelog-update.yml
│ ├── codeql-analysis.yml
│ ├── dep-auto-merge.yml
│ ├── lint-test.yml
│ └── release-tag.yml
├── .gitignore
├── CHANGELOG.md
├── LICENSE.md
├── README.md
├── async
├── async.go
└── async_test.go
├── auth
└── pdcp
│ ├── auth.go
│ ├── creds.go
│ └── creds_test.go
├── batcher
├── batcher.go
├── batcher_test.go
└── doc.go
├── buffer
└── disk.go
├── channelutil
├── README.md
├── clone.go
├── clone_join_test.go
├── join.go
└── utils.go
├── conn
└── connpool
│ ├── inflight.go
│ └── onetimepool.go
├── consts
└── errors.go
├── context
├── NContext.go
├── Ncontext_test.go
├── context.go
└── context_test.go
├── conversion
├── conversion.go
└── conversion_test.go
├── crypto
├── README.md
├── hash.go
├── hash_test.go
├── jarm
│ └── jarm.go
├── tls.go
└── ztls.go
├── dedupe
├── dedupe.go
├── dedupe_test.go
├── leveldb.go
└── map.go
├── dns
├── dnsutil.go
└── dnsutil_test.go
├── env
├── env.go
└── env_test.go
├── errkit
├── README.md
├── errors.go
├── errors_test.go
├── helpers.go
├── interfaces.go
└── kind.go
├── errors
├── enriched.go
├── err_test.go
├── err_with_fmt.go
├── err_with_fmt_test.go
├── errinterface.go
├── errlevel.go
└── errors.go
├── exec
├── README.md
├── executil.go
└── executil_test.go
├── file
├── README.md
├── clean.go
├── clean_test.go
├── file.go
├── file_test.go
└── tests
│ ├── empty_lines.txt
│ ├── path-traversal.txt
│ ├── pipe_separator.txt
│ └── standard.txt
├── folder
├── README.md
├── folderutil.go
├── folderutil_linux_test.go
├── folderutil_test.go
├── folderutil_win_test.go
└── std_dirs.go
├── generic
├── generic.go
├── generic_test.go
├── lockable.go
└── lockable_test.go
├── global
└── max_threads.go
├── go.mod
├── go.sum
├── healthcheck
├── connection.go
├── connection_test.go
├── dns.go
├── dns_test.go
├── environment.go
├── environment_test.go
├── healthcheck.go
├── path_permission.go
└── path_permission_test.go
├── http
├── README.md
├── chain.go
├── httputil.go
├── httputil_test.go
├── internal.go
├── normalization.go
├── respChain.go
└── response.go
├── io
├── io.go
└── io_test.go
├── ip
├── README.md
├── iputil.go
└── iputil_test.go
├── log
├── README.md
├── logutil.go
└── logutil_test.go
├── maps
├── README.md
├── generic_map.go
├── generic_map_test.go
├── mapsutil.go
├── mapsutil_test.go
├── ordered_map.go
├── ordered_map_test.go
├── synclock_map.go
└── synclock_map_test.go
├── memguardian
├── README.MD
├── doc.go
├── memguardian.go
├── memory.go
├── memory_linux.go
└── memory_others.go
├── memoize
├── cmd
│ └── main.go
├── gen
│ └── generic
│ │ └── memoize.go
├── memoize.go
├── memoize_test.go
├── package_template.tpl
├── simpleflight
│ └── simpleflight.go
├── templates.go
└── tests
│ └── test.go
├── ml
├── metrics
│ ├── classification_report.go
│ └── confusion_matrix.go
├── model_selection
│ └── model_selection.go
├── naive_bayes
│ ├── naive_bayes_classifier.go
│ └── naive_bayes_classifier_test.go
└── types.go
├── net
├── net.go
└── net_test.go
├── os
├── arch.go
└── os.go
├── patterns
├── doc.go
├── patterns.go
└── patterns_test.go
├── permission
├── README.md
├── error.go
├── permission.go
├── permission_file.go
├── permission_file_test.go
├── permission_linux.go
├── permission_other.go
├── permission_test.go
└── permission_win.go
├── ports
├── ports.go
└── ports_test.go
├── pprof
├── README.md
├── pprof.go
└── server.go
├── process
├── docker.go
└── process.go
├── proxy
├── README.md
├── burp.go
├── proxy.go
└── proxy_test.go
├── ptr
├── ptr.go
└── ptr_test.go
├── race
├── README.md
├── norace.go
└── race.go
├── rand
├── number.go
└── number_test.go
├── reader
├── conn_read.go
├── conn_read_test.go
├── error.go
├── examples
│ └── keypress
│ │ ├── buffered
│ │ └── keypress.go
│ │ └── raw
│ │ └── keypress.go
├── frozen_reader.go
├── frozen_reader_test.go
├── rawmode
│ ├── raw_mode.go
│ ├── raw_mode_posix.go
│ ├── raw_mode_windows.go
│ ├── values_darwin.go
│ └── values_linux.go
├── reader_keypress.go
├── reusable_read_closer.go
├── reusable_read_closer_test.go
├── timeout_reader.go
└── timeout_reader_test.go
├── reflect
├── README.md
├── reflectutil.go
├── reflectutil_test.go
└── tests
│ └── tests.go
├── routing
├── router.go
├── router_darwin.go
├── router_linux.go
└── router_windows.go
├── scripts
├── README.md
└── versionbump
│ ├── versionbump.go
│ └── versionbump_test.go
├── slice
├── README.md
├── sliceutil.go
├── sliceutil_test.go
├── sync_slice.go
└── sync_slice_test.go
├── strings
├── README.md
├── strings_encoding.go
├── strings_normalize.go
├── stringsutil.go
└── stringsutil_test.go
├── structs
├── structs.go
└── structs_test.go
├── sync
├── adaptivewaitgroup.go
├── adaptivewaitgroup_test.go
├── semaphore
│ └── semaphore.go
└── sizedpool
│ ├── sizedpool.go
│ └── sizedpool_test.go
├── syscallutil
├── syscall_unix.go
├── syscall_unix_others.go
├── syscallutil.go
├── syscallutil_test.go
└── syscallutil_win.go
├── sysutil
├── sysutil.go
└── sysutil_test.go
├── time
├── README.md
├── timeutil.go
└── timeutil_test.go
├── trace
├── trace.go
└── trace_test.go
├── unit
├── doc.go
└── size.go
├── update
├── gh.go
├── gh_test.go
├── types.go
├── types_test.go
├── update.go
├── utils_all.go
├── utils_linux.go
└── utils_test.go
└── url
├── README.md
├── merge_test.go
├── orderedparams.go
├── orderedparams_test.go
├── parsers.go
├── rawparam.go
├── rawparam_test.go
├── url.go
├── url_test.go
└── utils.go
/.chglog/CHANGELOG.tpl.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7 |
8 | {{ if .Versions -}}
9 |
10 | ## [Unreleased]
11 |
12 | {{ if .Unreleased.CommitGroups -}}
13 | {{ range .Unreleased.CommitGroups -}}
14 | ### {{ .Title }}
15 | {{ range .Commits -}}
16 | - {{ if .Scope }}**{{ .Scope }}:** {{ end }}{{ .Subject }}
17 | {{ end }}
18 | {{ end -}}
19 | {{ end -}}
20 | {{ end -}}
21 |
22 | {{ range .Versions }}
23 |
24 | ## {{ if .Tag.Previous }}[{{ .Tag.Name }}]{{ else }}{{ .Tag.Name }}{{ end }} - {{ datetime "2006-01-02" .Tag.Date }}
25 | {{ range .CommitGroups -}}
26 | ### {{ .Title }}
27 | {{ range .Commits -}}
28 | - {{ if .Scope }}**{{ .Scope }}:** {{ end }}{{ .Subject }}
29 | {{ end }}
30 | {{ end -}}
31 |
32 | {{- if .NoteGroups -}}
33 | {{ range .NoteGroups -}}
34 | ### {{ .Title }}
35 | {{ range .Notes }}
36 | {{ .Body }}
37 | {{ end }}
38 | {{ end -}}
39 | {{ end -}}
40 | {{ end -}}
41 |
42 | {{- if .Versions }}
43 | [Unreleased]: {{ .Info.RepositoryURL }}/compare/{{ $latest := index .Versions 0 }}{{ $latest.Tag.Name }}...HEAD
44 | {{ range .Versions -}}
45 | {{ if .Tag.Previous -}}
46 | [{{ .Tag.Name }}]: {{ $.Info.RepositoryURL }}/compare/{{ .Tag.Previous.Name }}...{{ .Tag.Name }}
47 | {{ end -}}
48 | {{ end -}}
49 | {{ end -}}
--------------------------------------------------------------------------------
/.chglog/config.yml:
--------------------------------------------------------------------------------
1 | style: github
2 | template: CHANGELOG.tpl.md
3 | info:
4 | title: CHANGELOG
5 | repository_url: https://github.com/projectdiscovery/utils
6 | options:
7 | commits:
8 | # filters:
9 | # Type:
10 | # - feat
11 | # - fix
12 | # - perf
13 | # - refactor
14 | commit_groups:
15 | # title_maps:
16 | # feat: Features
17 | # fix: Bug Fixes
18 | # perf: Performance Improvements
19 | # refactor: Code Refactoring
20 | header:
21 | pattern: "^((\\w+)\\s.*)$"
22 | pattern_maps:
23 | - Subject
24 | - Type
25 | notes:
26 | keywords:
27 | - BREAKING CHANGE
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
5 |
6 | version: 2
7 | updates:
8 |
9 | # Maintain dependencies for go modules
10 | - package-ecosystem: "gomod"
11 | directory: "/"
12 | schedule:
13 | interval: "weekly"
14 | target-branch: "main"
15 | commit-message:
16 | prefix: "chore"
17 | include: "scope"
18 | allow:
19 | - dependency-name: "github.com/projectdiscovery/*"
20 | groups:
21 | modules:
22 | patterns: ["github.com/projectdiscovery/*"]
23 | labels:
24 | - "Type: Maintenance"
25 |
26 | # # Maintain dependencies for docker
27 | # - package-ecosystem: "docker"
28 | # directory: "/"
29 | # schedule:
30 | # interval: "weekly"
31 | # target-branch: "main"
32 | # commit-message:
33 | # prefix: "chore"
34 | # include: "scope"
35 | #
36 | # # Maintain dependencies for GitHub Actions
37 | # - package-ecosystem: "github-actions"
38 | # directory: "/"
39 | # schedule:
40 | # interval: "weekly"
41 | # target-branch: "main"
42 | # commit-message:
43 | # prefix: "chore"
44 | # include: "scope"
45 |
--------------------------------------------------------------------------------
/.github/workflows/build-test.yml:
--------------------------------------------------------------------------------
1 | name: 🔨 Build Test
2 |
3 | on:
4 | pull_request:
5 | workflow_dispatch:
6 |
7 | jobs:
8 | build:
9 | name: Test Builds
10 | runs-on: ${{ matrix.os }}
11 | strategy:
12 | matrix:
13 | os: [ubuntu-latest, windows-latest, macOS-latest]
14 | steps:
15 | - name: Set up Go
16 | uses: actions/setup-go@v4
17 | with:
18 | go-version: 1.21.x
19 |
20 | - name: Check out code
21 | uses: actions/checkout@v3
22 |
23 | - name: Test
24 | run: go test ./...
25 |
26 | - name: Race Condition Tests
27 | if: ${{ matrix.os != 'windows-latest' }} # false positives in windows
28 | run: go test -race ./...
29 |
30 | - name: Fuzz File Read # fuzz tests need to be run separately
31 | run: go test -fuzztime=10s -fuzz=FuzzSafeOpen -run "FuzzSafeOpen" ./file/...
32 |
--------------------------------------------------------------------------------
/.github/workflows/changelog-update.yml:
--------------------------------------------------------------------------------
1 | name: Update Changelog
2 |
3 | on:
4 | push:
5 | tags:
6 | - '*'
7 | workflow_dispatch:
8 |
9 | jobs:
10 | update-changelog:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Set up Go
14 | uses: actions/setup-go@v4
15 | with:
16 | go-version: 1.21.x
17 |
18 | - name: Checkout code
19 | uses: actions/checkout@v3
20 | with:
21 | fetch-depth: 0
22 |
23 | - name: Update Changelog
24 | run: |
25 | go install github.com/git-chglog/git-chglog/cmd/git-chglog@latest
26 | git-chglog -o CHANGELOG.md
27 |
28 | - name: Commit changes
29 | run: |
30 | git config --local user.email "action@github.com"
31 | git config --local user.name "GitHub Action"
32 | git commit -a -m "update CHANGELOG.md"
33 |
34 | - name: Push changes
35 | uses: ad-m/github-push-action@master
36 | with:
37 | github_token: ${{ secrets.GITHUB_TOKEN }}
38 | branch: ${{ github.ref }}
39 |
--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
1 | name: 🚨 CodeQL Analysis
2 |
3 | on:
4 | workflow_dispatch:
5 | pull_request:
6 | branches:
7 | - main
8 |
9 | jobs:
10 | analyze:
11 | name: Analyze
12 | runs-on: ubuntu-latest
13 | permissions:
14 | actions: read
15 | contents: read
16 | security-events: write
17 |
18 | strategy:
19 | fail-fast: false
20 | matrix:
21 | language: [ 'go' ]
22 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
23 |
24 | steps:
25 | - name: Checkout repository
26 | uses: actions/checkout@v3
27 |
28 | # Initializes the CodeQL tools for scanning.
29 | - name: Initialize CodeQL
30 | uses: github/codeql-action/init@v2
31 | with:
32 | languages: ${{ matrix.language }}
33 |
34 | - name: Autobuild
35 | uses: github/codeql-action/autobuild@v2
36 |
37 | - name: Perform CodeQL Analysis
38 | uses: github/codeql-action/analyze@v2
39 |
--------------------------------------------------------------------------------
/.github/workflows/dep-auto-merge.yml:
--------------------------------------------------------------------------------
1 | name: 🤖 dep auto merge
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | workflow_dispatch:
8 |
9 | permissions:
10 | pull-requests: write
11 | issues: write
12 | repository-projects: write
13 |
14 | jobs:
15 | automerge:
16 | runs-on: ubuntu-latest
17 | if: github.actor == 'dependabot[bot]'
18 | steps:
19 | - uses: actions/checkout@v3
20 | with:
21 | token: ${{ secrets.DEPENDABOT_PAT }}
22 |
23 | - uses: ahmadnassri/action-dependabot-auto-merge@v2
24 | with:
25 | github-token: ${{ secrets.DEPENDABOT_PAT }}
26 | target: all
27 |
--------------------------------------------------------------------------------
/.github/workflows/lint-test.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | name: 🙏🏻 Lint Test
4 | on:
5 | pull_request:
6 | workflow_dispatch:
7 |
8 | jobs:
9 | lint:
10 | name: Lint Test
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout code
14 | uses: actions/checkout@v3
15 | - name: Set up Go
16 | uses: actions/setup-go@v4
17 | with:
18 | go-version: 1.21.x
19 | - name: Run golangci-lint
20 | uses: golangci/golangci-lint-action@v3.7.0
21 | with:
22 | version: latest
23 | args: --timeout 5m
24 | working-directory: .
25 |
--------------------------------------------------------------------------------
/.github/workflows/release-tag.yml:
--------------------------------------------------------------------------------
1 | name: 🔖 Auto release gh action
2 |
3 | on:
4 | workflow_dispatch:
5 | schedule:
6 | - cron: '0 0 * * 0'
7 |
8 | jobs:
9 | build:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Check out code
13 | uses: actions/checkout@v3
14 | with:
15 | fetch-depth: 0
16 |
17 | - name: Get Commit Count
18 | id: get_commit
19 | run: git rev-list `git rev-list --tags --no-walk --max-count=1`..HEAD --count | xargs -I {} echo COMMIT_COUNT={} >> $GITHUB_OUTPUT
20 |
21 | - name: Create release and tag
22 | if: ${{ steps.get_commit.outputs.COMMIT_COUNT > 0 }}
23 | id: tag_version
24 | uses: mathieudutour/github-tag-action@v6.1
25 | with:
26 | github_token: ${{ secrets.GITHUB_TOKEN }}
27 |
28 | - name: Create a GitHub release
29 | if: ${{ steps.get_commit.outputs.COMMIT_COUNT > 0 }}
30 | uses: actions/create-release@v1
31 | env:
32 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
33 | with:
34 | tag_name: ${{ steps.tag_version.outputs.new_tag }}
35 | release_name: Release ${{ steps.tag_version.outputs.new_tag }}
36 | body: ${{ steps.tag_version.outputs.changelog }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .vscode
3 | *.exe
4 |
5 | .devcontainer
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 ProjectDiscovery, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # utils
2 | This repository contains various helper libraries.
--------------------------------------------------------------------------------
/async/async.go:
--------------------------------------------------------------------------------
1 | package async
2 |
3 | import "context"
4 |
// Future is the handle returned by Exec. It mimics the async/await paradigm:
// the work is already running, and Await collects its outcome.
type Future[T any] interface {
	Await() (T, error)
}

// future is the concrete Future produced by Exec; await waits for the
// background work (or for ctx to be done) and yields its outcome.
type future[T any] struct {
	await func(ctx context.Context) (T, error)
}

// Await blocks until the wrapped computation completes. It waits with
// context.Background, so it never gives up early.
func (f future[T]) Await() (T, error) {
	ctx := context.Background()
	return f.await(ctx)
}

// Exec starts f in a new goroutine and immediately returns a Future whose
// Await method blocks until f has finished, then yields f's value and error.
func Exec[T any](f func() (T, error)) Future[T] {
	done := make(chan struct{})
	var (
		value  T
		valErr error
	)

	go func() {
		// closing done publishes value/valErr to whoever is waiting
		defer close(done)
		value, valErr = f()
	}()

	wait := func(ctx context.Context) (T, error) {
		select {
		case <-done:
			return value, valErr
		case <-ctx.Done():
			return value, ctx.Err()
		}
	}
	return future[T]{await: wait}
}
40 |
--------------------------------------------------------------------------------
/async/async_test.go:
--------------------------------------------------------------------------------
1 | package async
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/stretchr/testify/require"
8 | )
9 |
10 | func TestAsync(t *testing.T) {
11 | // Async
12 | do := Exec(func() (bool, error) {
13 | time.Sleep(2 * time.Second)
14 | return true, nil
15 | })
16 |
17 | // do some other stuff
18 | time.Sleep(time.Second)
19 |
20 | // Await
21 | ok, err := do.Await()
22 | require.Nil(t, err)
23 | require.True(t, ok)
24 | }
25 |
--------------------------------------------------------------------------------
/auth/pdcp/creds_test.go:
--------------------------------------------------------------------------------
1 | package pdcp
2 |
3 | import (
4 | "os"
5 | "path/filepath"
6 | "strings"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/require"
10 | )
11 |
// exampleCred is the credentials fixture that TestLoadCreds writes (after
// strings.TrimSpace) to a temporary file. It holds a single entry whose
// username, api-key and server values the test asserts on.
var exampleCred = `
- username: test
email: test@projectdiscovery.io
api-key: testpassword
server: https://scanme.sh
`
18 |
19 | func TestLoadCreds(t *testing.T) {
20 | // temporarily change PDCP file location for testing
21 | f, err := os.CreateTemp("", "creds-test-*")
22 | require.Nil(t, err)
23 | _, _ = f.WriteString(strings.TrimSpace(exampleCred))
24 | defer os.Remove(f.Name())
25 | PDCPCredFile = f.Name()
26 | PDCPDir = filepath.Dir(f.Name())
27 | h := &PDCPCredHandler{}
28 | value, err := h.GetCreds()
29 | require.Nil(t, err)
30 | require.NotNil(t, value)
31 | require.Equal(t, "test", value.Username)
32 | require.Equal(t, "testpassword", value.APIKey)
33 | require.Equal(t, "https://scanme.sh", value.Server)
34 | }
35 |
--------------------------------------------------------------------------------
/batcher/batcher_test.go:
--------------------------------------------------------------------------------
1 | package batcher
2 |
3 | import (
4 | "crypto/rand"
5 | "testing"
6 | "time"
7 |
8 | "github.com/stretchr/testify/require"
9 | )
10 |
11 | func TestBatcherStandard(t *testing.T) {
12 | var (
13 | batchSize = 100
14 | wanted = 100000
15 | minWantedBatches = wanted / batchSize
16 | got int
17 | gotBatches int
18 | )
19 | callback := func(t []int) {
20 | gotBatches++
21 | for range t {
22 | got++
23 | }
24 | }
25 | bat := New[int](
26 | WithMaxCapacity[int](batchSize),
27 | WithFlushCallback[int](callback),
28 | )
29 |
30 | bat.Run()
31 |
32 | for i := 0; i < wanted; i++ {
33 | bat.Append(i)
34 | }
35 |
36 | bat.Stop()
37 |
38 | bat.WaitDone()
39 |
40 | require.Equal(t, wanted, got)
41 | require.True(t, minWantedBatches <= gotBatches)
42 | }
43 |
44 | func TestBatcherWithInterval(t *testing.T) {
45 | var (
46 | batchSize = 200
47 | wanted = 1000
48 | minWantedBatches = 10
49 | got int
50 | gotBatches int
51 | )
52 | callback := func(t []int) {
53 | gotBatches++
54 | for range t {
55 | got++
56 | }
57 | }
58 | bat := New[int](
59 | WithMaxCapacity[int](batchSize),
60 | WithFlushCallback[int](callback),
61 | WithFlushInterval[int](10*time.Millisecond),
62 | )
63 |
64 | bat.Run()
65 |
66 | for i := 0; i < wanted; i++ {
67 | time.Sleep(2 * time.Millisecond)
68 | bat.Append(i)
69 | }
70 |
71 | bat.Stop()
72 |
73 | bat.WaitDone()
74 |
75 | require.Equal(t, wanted, got)
76 | require.True(t, minWantedBatches <= gotBatches)
77 | }
78 |
// exampleBatcherStruct is the item type used by TestBatcherWithSizeLimit;
// the test sums len(Value) over each flushed batch to check the size limit.
type exampleBatcherStruct struct {
	Value []byte
}
82 |
83 | func TestBatcherWithSizeLimit(t *testing.T) {
84 | var (
85 | batchSize = 100
86 | maxSize = 1000
87 | wanted = 10
88 | gotBatches int
89 | )
90 | var failedIteration bool
91 |
92 | callback := func(ta []exampleBatcherStruct) {
93 | gotBatches++
94 |
95 | totalValueSize := 0
96 | for _, t := range ta {
97 | totalValueSize += len(t.Value)
98 | }
99 | if totalValueSize > maxSize {
100 | failedIteration = true
101 | }
102 | }
103 | bat := New[exampleBatcherStruct](
104 | WithMaxCapacity[exampleBatcherStruct](batchSize),
105 | WithMaxSize[exampleBatcherStruct](int32(maxSize)),
106 | WithFlushCallback[exampleBatcherStruct](callback),
107 | )
108 |
109 | bat.Run()
110 |
111 | for i := 0; i < wanted; i++ {
112 | randData := make([]byte, 200)
113 | _, _ = rand.Read(randData)
114 | bat.Append(exampleBatcherStruct{Value: randData})
115 | }
116 |
117 | bat.Stop()
118 |
119 | bat.WaitDone()
120 |
121 | require.False(t, failedIteration)
122 | }
123 |
--------------------------------------------------------------------------------
/batcher/doc.go:
--------------------------------------------------------------------------------
1 | // Package batcher provides a simple batching mechanism.
2 | // The buffer can be configured with a max capacity and a flush interval;
3 | // it invokes a callback function when the buffer is full or the flush interval is reached.
4 | package batcher
5 |
--------------------------------------------------------------------------------
/buffer/disk.go:
--------------------------------------------------------------------------------
1 | package buffer
2 |
3 | import (
4 | "io"
5 | "os"
6 | )
7 |
// DiskBuffer is a buffer backed by a temporary file on disk.
type DiskBuffer struct {
	f *os.File
}

// New creates a DiskBuffer backed by a fresh temporary file in the default
// temporary directory. Call Close to delete the file.
func New() (*DiskBuffer, error) {
	f, err := os.CreateTemp("", "")
	if err != nil {
		return nil, err
	}

	return &DiskBuffer{f: f}, nil
}

// Write writes b at the backing file's current offset.
func (db *DiskBuffer) Write(b []byte) (int, error) {
	return db.f.Write(b)
}

// WriteAt writes b at byte offset off of the backing file.
func (db *DiskBuffer) WriteAt(b []byte, off int64) (int, error) {
	return db.f.WriteAt(b, off)
}

// WriteString writes s at the backing file's current offset.
func (db *DiskBuffer) WriteString(s string) (int, error) {
	return db.f.WriteString(s)
}

// Bytes returns the whole content of the backing file, re-read from disk.
func (db *DiskBuffer) Bytes() ([]byte, error) {
	return os.ReadFile(db.f.Name())
}

// String is Bytes converted to a string; any read error is returned alongside.
func (db *DiskBuffer) String() (string, error) {
	data, err := db.Bytes()
	return string(data), err
}

// Reader opens a new, independent read-only handle on the backing file.
// All readers must be closed to avoid FD leaks. On failure the returned
// reader is a true nil interface, so `r != nil` checks are reliable.
func (db *DiskBuffer) Reader() (io.ReadSeekCloser, error) {
	f, err := os.Open(db.f.Name())
	if err != nil {
		// Returning f here would wrap a nil *os.File in a non-nil interface.
		return nil, err
	}
	return f, nil
}

// Close closes the backing file and removes it from disk. Errors are
// deliberately ignored because Close has no error result.
func (db *DiskBuffer) Close() {
	name := db.f.Name()
	_ = db.f.Close()
	_ = os.RemoveAll(name)
}
53 |
--------------------------------------------------------------------------------
/channelutil/utils.go:
--------------------------------------------------------------------------------
1 | package channelutil
2 |
// CreateNChannels returns a map of count freshly made channels of T, keyed
// 0 through count-1, each with buffer capacity bufflen. A count of zero or
// less yields an empty map.
func CreateNChannels[T any](count int, bufflen int) map[int]chan T {
	channels := make(map[int]chan T)

	idx := 0
	for idx < count {
		channels[idx] = make(chan T, bufflen)
		idx++
	}
	return channels
}
12 |
--------------------------------------------------------------------------------
/conn/connpool/inflight.go:
--------------------------------------------------------------------------------
1 | package connpool
2 |
3 | import (
4 | "errors"
5 | "net"
6 |
7 | mapsutil "github.com/projectdiscovery/utils/maps"
8 | "go.uber.org/multierr"
9 | )
10 |
11 | type InFlightConns struct {
12 | inflightConns *mapsutil.SyncLockMap[net.Conn, struct{}]
13 | }
14 |
15 | func NewInFlightConns() (*InFlightConns, error) {
16 | m := &mapsutil.SyncLockMap[net.Conn, struct{}]{
17 | Map: mapsutil.Map[net.Conn, struct{}]{},
18 | }
19 | return &InFlightConns{inflightConns: m}, nil
20 | }
21 |
22 | func (i *InFlightConns) Add(conn net.Conn) {
23 | _ = i.inflightConns.Set(conn, struct{}{})
24 | }
25 |
26 | func (i *InFlightConns) Remove(conn net.Conn) {
27 | i.inflightConns.Delete(conn)
28 | }
29 |
30 | func (i *InFlightConns) Close() error {
31 | var errs []error
32 |
33 | _ = i.inflightConns.Iterate(func(conn net.Conn, _ struct{}) error {
34 | if err := conn.Close(); err != nil {
35 | errs = append(errs, err)
36 | }
37 | return nil
38 | })
39 |
40 | if ok := i.inflightConns.Clear(); !ok {
41 | errs = append(errs, errors.New("couldn't empty in flight connections"))
42 | }
43 |
44 | return multierr.Combine(errs...)
45 | }
46 |
--------------------------------------------------------------------------------
/conn/connpool/onetimepool.go:
--------------------------------------------------------------------------------
1 | package connpool
2 |
3 | import (
4 | "context"
5 | "net"
6 | "sync"
7 | )
8 |
// Dialer establishes connections for OneTimePool. When a pool's Dialer field
// is non-nil, Run uses it (with the pool context and network "tcp") instead
// of the default dialer.
type Dialer interface {
	Dial(ctx context.Context, network, address string) (net.Conn, error)
}
12 |
// OneTimePool is a pool designed to create continuous bare connections that are for one time only usage.
// Run keeps dialing address and parks each new connection in idleConnections
// until a caller takes it with Acquire.
type OneTimePool struct {
	// address is the target dialed over "tcp".
	address string
	// idleConnections holds dialed connections waiting for Acquire; its
	// capacity is the poolSize passed to NewOneTimePool.
	idleConnections chan net.Conn
	// InFlightConns tracks dialed connections not yet acquired; Close closes them.
	InFlightConns *InFlightConns
	// ctx/cancel bound the pool's lifetime; Close calls cancel.
	ctx    context.Context
	cancel context.CancelFunc
	// Dialer optionally overrides how connections are dialed; nil means default.
	Dialer Dialer
	// mx guards Dialer.
	mx sync.RWMutex
}
23 |
24 | func NewOneTimePool(ctx context.Context, address string, poolSize int) (*OneTimePool, error) {
25 | idleConnections := make(chan net.Conn, poolSize)
26 | inFlightConns, err := NewInFlightConns()
27 | if err != nil {
28 | return nil, err
29 | }
30 | pool := &OneTimePool{
31 | address: address,
32 | idleConnections: idleConnections,
33 | InFlightConns: inFlightConns,
34 | }
35 | if ctx == nil {
36 | ctx = context.Background()
37 | }
38 | pool.ctx, pool.cancel = context.WithCancel(ctx)
39 | return pool, nil
40 | }
41 |
42 | // Acquire acquires an idle connection from the pool
43 | func (p *OneTimePool) Acquire(c context.Context) (net.Conn, error) {
44 | select {
45 | case <-p.ctx.Done():
46 | return nil, p.ctx.Err()
47 | case <-c.Done():
48 | return nil, c.Err()
49 | case conn := <-p.idleConnections:
50 | p.InFlightConns.Remove(conn)
51 | return conn, nil
52 | }
53 | }
54 |
55 | func (p *OneTimePool) Run() error {
56 | for {
57 | select {
58 | case <-p.ctx.Done():
59 | return p.ctx.Err()
60 | default:
61 | var (
62 | conn net.Conn
63 | err error
64 | )
65 | p.mx.RLock()
66 | hasDialer := p.Dialer != nil
67 | p.mx.RUnlock()
68 |
69 | if hasDialer {
70 | p.mx.RLock()
71 | conn, err = p.Dialer.Dial(p.ctx, "tcp", p.address)
72 | p.mx.RUnlock()
73 | } else {
74 | conn, err = net.Dial("tcp", p.address)
75 | }
76 | if err == nil {
77 | p.InFlightConns.Add(conn)
78 | select {
79 | case <-p.ctx.Done():
80 | return p.ctx.Err()
81 | case p.idleConnections <- conn:
82 | }
83 | }
84 | }
85 | }
86 | }
87 |
88 | func (p *OneTimePool) Close() error {
89 | p.cancel()
90 |
91 | // remove dialer references
92 | p.mx.Lock()
93 | p.Dialer = nil
94 | p.mx.Unlock()
95 |
96 | return p.InFlightConns.Close()
97 | }
98 |
--------------------------------------------------------------------------------
/consts/errors.go:
--------------------------------------------------------------------------------
1 | package consts
2 |
3 | import "errors"
4 |
// ErrNotSupported is the shared sentinel for operations that are not
// supported; compare against it with errors.Is.
var ErrNotSupported = errors.New("not supported")
8 |
--------------------------------------------------------------------------------
/context/NContext.go:
--------------------------------------------------------------------------------
1 | package contextutil
2 |
3 | import "context"
4 |
5 | // A problematic situation when implementing context in a function
6 | // is when that function has more than one return value.
7 | // If the function has only one return value, we can safely wrap it like this:
8 | /*
9 | func DoSomething() error {}
10 | ch := make(chan error)
11 | go func() {
12 | ch <- DoSomething()
13 | }()
14 | select {
15 | case err := <-ch:
16 | // handle error
17 | case <-ctx.Done():
18 | // handle context cancelation
19 | }
20 | */
21 | // but what if we have more than one value to return?
22 | // we can use generics and a struct and that is what we are doing here
23 | // here we use struct and generics to store return values of a function
24 | // instead of storing it in a []interface{}
25 |
// twoValueCtx carries the two results of a wrapped function over a channel.
type twoValueCtx[T1 any, T2 any] struct {
	var1 T1
	var2 T2
}

// threeValueCtx carries the three results of a wrapped function over a channel.
type threeValueCtx[T1 any, T2 any, T3 any] struct {
	var1 T1
	var2 T2
	var3 T3
}

// ExecFunc implements context for a function which has no return values.
// It runs fn in a goroutine and returns nil if fn finishes before ctx is
// done, otherwise ctx.Err(). fn itself is not interrupted on cancellation;
// its goroutine exits once fn returns.
func ExecFunc(ctx context.Context, fn func()) error {
	// Buffered so the goroutine can always complete its send and exit, even
	// after we stopped listening because ctx was cancelled (no leak).
	ch := make(chan struct{}, 1)
	go func() {
		fn()
		ch <- struct{}{}
	}()
	select {
	case <-ch:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// ExecFuncWithTwoReturns wraps a function which has two return values given that last one is error
// and executes that function in a goroutine, thereby implementing context.
// If ctx is done before fn returns it returns the zero T1 and ctx.Err();
// otherwise it returns fn's return values. fn is not interrupted on cancellation.
func ExecFuncWithTwoReturns[T1 any](ctx context.Context, fn func() (T1, error)) (T1, error) {
	// Buffered so the goroutine never blocks forever once ctx wins the select.
	ch := make(chan twoValueCtx[T1, error], 1)
	go func() {
		x, y := fn()
		ch <- twoValueCtx[T1, error]{var1: x, var2: y}
	}()
	select {
	case <-ctx.Done():
		var zero T1
		return zero, ctx.Err()
	case v := <-ch:
		return v.var1, v.var2
	}
}

// ExecFuncWithThreeReturns wraps a function which has three return values given that last one is error
// and executes that function in a goroutine, thereby implementing context.
// If ctx is done before fn returns it returns zero values and ctx.Err();
// otherwise it returns fn's return values. fn is not interrupted on cancellation.
func ExecFuncWithThreeReturns[T1 any, T2 any](ctx context.Context, fn func() (T1, T2, error)) (T1, T2, error) {
	// Buffered so the goroutine never blocks forever once ctx wins the select.
	ch := make(chan threeValueCtx[T1, T2, error], 1)
	go func() {
		x, y, z := fn()
		ch <- threeValueCtx[T1, T2, error]{var1: x, var2: y, var3: z}
	}()
	select {
	case <-ctx.Done():
		var zero1 T1
		var zero2 T2
		return zero1, zero2, ctx.Err()
	case v := <-ch:
		return v.var1, v.var2, v.var3
	}
}
92 |
--------------------------------------------------------------------------------
/context/context.go:
--------------------------------------------------------------------------------
1 | package contextutil
2 |
3 | import (
4 | "context"
5 | "errors"
6 | )
7 |
// ErrIncorrectNumberOfItems is returned by WithValues when it receives an odd
// number of items, i.e. a key without a value.
var ErrIncorrectNumberOfItems = errors.New("number of items is not even")

// DefaultContext is the fallback returned by ValueOrDefault.
var DefaultContext = context.TODO()

// ContextArg is the key and value type accepted by WithValues.
type ContextArg string

// WithValues combines multiple key-value pairs into an existing context:
// keyValue[0]→keyValue[1], keyValue[2]→keyValue[3], and so on. Later pairs
// shadow earlier ones with the same key. If keyValue has an odd length, ctx
// is returned unchanged together with ErrIncorrectNumberOfItems.
func WithValues(ctx context.Context, keyValue ...ContextArg) (context.Context, error) {
	if len(keyValue)%2 != 0 {
		return ctx, ErrIncorrectNumberOfItems
	}

	// Step by two so each iteration consumes exactly one key and its value;
	// stepping by one would also store every value as a key for the next item.
	for i := 0; i+1 < len(keyValue); i += 2 {
		ctx = context.WithValue(ctx, keyValue[i], keyValue[i+1]) //nolint
	}
	return ctx, nil
}

// ValueOrDefault returns default context if given is nil (using interface to avoid static check reporting).
// Any value that is not a non-nil context.Context yields DefaultContext.
func ValueOrDefault(value interface{}) context.Context {
	if ctx, ok := value.(context.Context); ok && ctx != nil {
		return ctx
	}

	return DefaultContext
}
34 |
--------------------------------------------------------------------------------
/context/context_test.go:
--------------------------------------------------------------------------------
1 | package contextutil
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/require"
8 | )
9 |
10 | func TestWithValues(t *testing.T) {
11 | type testCase struct {
12 | name string
13 | keyValue []ContextArg
14 | expectedError error
15 | expectedValue map[ContextArg]ContextArg
16 | }
17 |
18 | var testCases = []testCase{
19 | {
20 | name: "even number of key-value pairs",
21 | keyValue: []ContextArg{"key1", "value1", "key2", "value2"},
22 | expectedError: nil,
23 | expectedValue: map[ContextArg]ContextArg{"key1": "value1", "key2": "value2"},
24 | },
25 | {
26 | name: "odd number of key-value pairs",
27 | keyValue: []ContextArg{"key1", "value1", "key2"},
28 | expectedError: ErrIncorrectNumberOfItems,
29 | expectedValue: map[ContextArg]ContextArg{},
30 | },
31 | {
32 | name: "overwriting values",
33 | keyValue: []ContextArg{"key1", "value1", "key1", "newValue"},
34 | expectedError: nil,
35 | expectedValue: map[ContextArg]ContextArg{"key1": "newValue"},
36 | },
37 | }
38 | ctx := context.Background()
39 | for _, tc := range testCases {
40 | t.Run(tc.name, func(t *testing.T) {
41 | newCtx, err := WithValues(ctx, tc.keyValue...)
42 | if tc.expectedError != nil {
43 | require.ErrorIs(t, err, tc.expectedError)
44 | require.Equal(t, ctx, newCtx, "Expected original context to be returned")
45 | }
46 |
47 | for key, expectedVal := range tc.expectedValue {
48 | if val := newCtx.Value(key); val != expectedVal {
49 | t.Errorf("Expected %s but got %v", expectedVal, val)
50 | }
51 | }
52 | })
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/conversion/conversion.go:
--------------------------------------------------------------------------------
1 | package conversion
2 |
3 | import "unsafe"
4 |
// Bytes returns a []byte that aliases the memory of s, without copying.
//
// The returned slice MUST NOT be modified: Go strings are immutable, and
// writing through the slice is undefined behavior (it can crash when s is
// backed by read-only memory, e.g. a string literal). For an empty s the
// result has length 0 and may be nil.
func Bytes(s string) []byte {
	data := unsafe.StringData(s)
	return unsafe.Slice(data, len(s))
}
8 |
// String returns a string that shares the memory of b, without copying.
//
// b must not be modified after the call, because the returned string would
// change with it and Go code assumes strings are immutable. An empty or nil b
// yields "".
func String(b []byte) string {
	n := len(b)
	if n > 0 {
		return unsafe.String(unsafe.SliceData(b), n)
	}
	return ""
}
15 |
--------------------------------------------------------------------------------
/conversion/conversion_test.go:
--------------------------------------------------------------------------------
1 | package conversion
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 | )
7 |
8 | func TestBytes(t *testing.T) {
9 | testCases := []struct {
10 | input string
11 | expected []byte
12 | }{
13 | {"test", []byte("test")},
14 | {"", []byte("")},
15 | }
16 |
17 | for _, tc := range testCases {
18 | result := Bytes(tc.input)
19 | if !bytes.Equal(result, tc.expected) {
20 | t.Errorf("Expected %v, but got %v", tc.expected, result)
21 | }
22 | }
23 | }
24 |
25 | func TestString(t *testing.T) {
26 | testCases := []struct {
27 | input []byte
28 | expected string
29 | }{
30 | {[]byte("test"), "test"},
31 | {[]byte(""), ""},
32 | }
33 |
34 | for _, tc := range testCases {
35 | result := String(tc.input)
36 | if result != tc.expected {
37 | t.Errorf("Expected %s, but got %s", tc.expected, result)
38 | }
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/crypto/README.md:
--------------------------------------------------------------------------------
1 | # cryptoutil
This package contains various cryptography helpers.
--------------------------------------------------------------------------------
/crypto/hash.go:
--------------------------------------------------------------------------------
1 | package cryptoutil
2 |
3 | import (
4 | "crypto/sha256"
5 | "encoding/hex"
6 | )
7 |
// SHA256Sum returns the lowercase hex encoding of the SHA-256 digest of data.
// data must be a string or a []byte; for any other type (including nil) the
// empty string is returned.
func SHA256Sum(data interface{}) string {
	var payload []byte
	switch v := data.(type) {
	case []byte:
		payload = v
	case string:
		payload = []byte(v)
	default:
		return ""
	}

	digest := sha256.Sum256(payload)
	return hex.EncodeToString(digest[:])
}
20 |
--------------------------------------------------------------------------------
/crypto/hash_test.go:
--------------------------------------------------------------------------------
1 | package cryptoutil
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/require"
7 | )
8 |
9 | func TestSHA256Sum(t *testing.T) {
10 | tests := map[string]string{
11 | "test": "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08",
12 | "test1": "1b4f0e9851971998e732078544c96b36c3d01cedf7caa332359d6f1d83567014",
13 | }
14 | for item, hash := range tests {
15 | require.Equal(t, hash, SHA256Sum(item), "hash is different")
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/crypto/jarm/jarm.go:
--------------------------------------------------------------------------------
1 | package jarm
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net"
7 | "strings"
8 | "time"
9 |
10 | gojarm "github.com/hdm/jarm-go"
11 | connpool "github.com/projectdiscovery/utils/conn/connpool"
12 | )
13 |
14 | // PoolCount defines how many connection are kept in the pool
15 | var PoolCount = 3
16 |
17 | // fingerprint probes a single host/port
18 | func HashWithDialer(dialer connpool.Dialer, host string, port int, duration int) (string, error) {
19 | var results []string
20 | addr := net.JoinHostPort(host, fmt.Sprintf("%d", port))
21 |
22 | timeout := time.Duration(duration) * time.Second
23 | ctx, cancel := context.WithTimeout(context.Background(), (time.Duration(duration*PoolCount) * time.Second))
24 | defer cancel()
25 |
26 | // using connection pool as we need multiple probes
27 | pool, err := connpool.NewOneTimePool(ctx, addr, PoolCount)
28 | if err != nil {
29 | return "", err
30 | }
31 | pool.Dialer = dialer
32 |
33 | defer func() { _ = pool.Close() }()
34 | go func() { _ = pool.Run() }()
35 |
36 | for _, probe := range gojarm.GetProbes(host, port) {
37 | conn, err := pool.Acquire(ctx)
38 | if err != nil {
39 | continue
40 | }
41 | if conn == nil {
42 | continue
43 | }
44 | _ = conn.SetWriteDeadline(time.Now().Add(timeout))
45 | _, err = conn.Write(gojarm.BuildProbe(probe))
46 | if err != nil {
47 | results = append(results, "")
48 | _ = conn.Close()
49 | continue
50 | }
51 | _ = conn.SetReadDeadline(time.Now().Add(timeout))
52 | buff := make([]byte, 1484)
53 | _, _ = conn.Read(buff)
54 | _ = conn.Close()
55 | ans, err := gojarm.ParseServerHello(buff, probe)
56 | if err != nil {
57 | results = append(results, "")
58 | continue
59 | }
60 | results = append(results, ans)
61 | }
62 | hash := gojarm.RawHashToFuzzyHash(strings.Join(results, ","))
63 | return hash, nil
64 | }
65 |
--------------------------------------------------------------------------------
/dedupe/dedupe.go:
--------------------------------------------------------------------------------
1 | package dedupe
2 |
// MaxInMemoryDedupeSize is the largest expected input size, in bytes, for
// which NewDedupe keeps unique items in an in-memory map; larger inputs use the
// LevelDB backend instead. Default: 100 MiB.
var MaxInMemoryDedupeSize = 100 << 20
5 |
// DedupeBackend is the storage Dedupe uses to remember which strings it has
// already seen.
type DedupeBackend interface {
	// Upsert records elem and reports whether it was new. It returns false
	// when elem was already stored (implementations may also return false
	// when storing fails).
	Upsert(elem string) bool
	// IterCallback invokes callback once for every stored element.
	IterCallback(callback func(elem string))
	// Cleanup releases whatever the backend holds once deduping is finished.
	Cleanup()
}

// Dedupe removes duplicate strings read from a channel, keeping the set of
// unique values in a DedupeBackend.
type Dedupe struct {
	receive <-chan string
	backend DedupeBackend
}

// Option is called by Drain with every value that is seen for the first time.
type Option func(val string)

// WithUnique returns an Option that forwards each unique value to ch. The send
// blocks, so ch must be drained (or buffered enough) while Drain runs.
func WithUnique(ch chan<- string) Option {
	forward := func(val string) {
		ch <- val
	}
	return forward
}

// Drain consumes the receive channel until it is closed. Each value the
// backend reports as new is passed, in order, to every option in opts.
func (d *Dedupe) Drain(opts ...Option) {
	for val := range d.receive {
		if !d.backend.Upsert(val) {
			continue
		}
		for i := range opts {
			opts[i](val)
		}
	}
}

// GetResults streams every stored element over the returned channel, then
// calls the backend's Cleanup and closes the channel. The work runs in its own
// goroutine, which blocks once the 100-item buffer is full, so the caller
// should read until the channel is closed.
func (d *Dedupe) GetResults() <-chan string {
	out := make(chan string, 100)
	go func() {
		defer close(out)
		d.backend.IterCallback(func(elem string) { out <- elem })
		d.backend.Cleanup()
	}()
	return out
}
55 |
56 | // NewDedupe returns a dedupe instance which removes all duplicates
57 | // Note: If byteLen is not correct/specified alterx may consume lot of memory
58 | func NewDedupe(ch <-chan string, byteLen int) *Dedupe {
59 | d := &Dedupe{
60 | receive: ch,
61 | }
62 | if byteLen <= MaxInMemoryDedupeSize {
63 | d.backend = NewMapBackend()
64 | } else {
65 | // gologger print a info message here
66 | d.backend = NewLevelDBBackend()
67 | }
68 | return d
69 | }
70 |
--------------------------------------------------------------------------------
/dedupe/dedupe_test.go:
--------------------------------------------------------------------------------
1 | package dedupe
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestDedupe(t *testing.T) {
8 | t.Run("MapBackend", func(t *testing.T) {
9 | receiveCh := make(chan string, 10)
10 | dedupe := NewDedupe(receiveCh, 1)
11 |
12 | receiveCh <- "test1"
13 | receiveCh <- "test2"
14 | receiveCh <- "test1"
15 | close(receiveCh)
16 |
17 | resultCh := make(chan string, 10)
18 | dedupe.Drain(WithUnique(resultCh))
19 | close(resultCh)
20 |
21 | results := collectResults(resultCh)
22 |
23 | if len(results) != 2 {
24 | t.Fatalf("expected 2 unique items, got %d", len(results))
25 | }
26 | })
27 |
28 | t.Run("LevelDBBackend", func(t *testing.T) {
29 | receiveCh := make(chan string, 10)
30 | dedupe := NewDedupe(receiveCh, MaxInMemoryDedupeSize+1)
31 |
32 | receiveCh <- "testA"
33 | receiveCh <- "testB"
34 | receiveCh <- "testA"
35 | close(receiveCh)
36 |
37 | resultCh := make(chan string, 10)
38 | dedupe.Drain(WithUnique(resultCh))
39 | close(resultCh)
40 |
41 | results := collectResults(resultCh)
42 |
43 | if len(results) != 2 {
44 | t.Fatalf("expected 2 unique items, got %d", len(results))
45 | }
46 | })
47 |
48 | t.Run("Drain", func(t *testing.T) {
49 | receiveCh := make(chan string, 10)
50 | dedupe := NewDedupe(receiveCh, 1)
51 |
52 | receiveCh <- "testX"
53 | receiveCh <- "testY"
54 | receiveCh <- "testX"
55 | close(receiveCh)
56 |
57 | resultCh := make(chan string, 10)
58 | dedupe.Drain(WithUnique(resultCh))
59 | close(resultCh)
60 |
61 | results := collectResults(resultCh)
62 |
63 | if len(results) != 2 {
64 | t.Fatalf("expected 2 unique items, got %d", len(results))
65 | }
66 | })
67 | }
68 |
// collectResults drains ch until it is closed and returns the received items
// in arrival order (nil when nothing was received).
func collectResults(ch <-chan string) []string {
	var items []string
	for {
		item, ok := <-ch
		if !ok {
			return items
		}
		items = append(items, item)
	}
}
76 |
--------------------------------------------------------------------------------
/dedupe/leveldb.go:
--------------------------------------------------------------------------------
1 | package dedupe
2 |
3 | import (
4 | "github.com/projectdiscovery/gologger"
5 | "github.com/projectdiscovery/hmap/store/hybrid"
6 | )
7 |
8 | type LevelDBBackend struct {
9 | storage *hybrid.HybridMap
10 | }
11 |
12 | func NewLevelDBBackend() *LevelDBBackend {
13 | l := &LevelDBBackend{}
14 | db, err := hybrid.New(hybrid.DefaultDiskOptions)
15 | if err != nil {
16 | gologger.Fatal().Msgf("failed to create temp dir for alterx dedupe got: %v", err)
17 | }
18 | l.storage = db
19 | return l
20 | }
21 |
22 | func (l *LevelDBBackend) Upsert(elem string) bool {
23 | _, exists := l.storage.Get(elem)
24 | if exists {
25 | return false
26 | }
27 |
28 | if err := l.storage.Set(elem, nil); err != nil {
29 | gologger.Error().Msgf("dedupe: leveldb: got %v while writing %v", err, elem)
30 | return false
31 | }
32 | return true
33 | }
34 |
35 | func (l *LevelDBBackend) IterCallback(callback func(elem string)) {
36 | l.storage.Scan(func(k, _ []byte) error {
37 | callback(string(k))
38 | return nil
39 | })
40 | }
41 |
42 | func (l *LevelDBBackend) Cleanup() {
43 | _ = l.storage.Close()
44 | }
45 |
--------------------------------------------------------------------------------
/dedupe/map.go:
--------------------------------------------------------------------------------
1 | package dedupe
2 |
3 | import "runtime/debug"
4 |
// MapBackend is an in-memory DedupeBackend that uses a Go map as a set.
type MapBackend struct {
	storage map[string]struct{}
}

// NewMapBackend returns an empty MapBackend ready for use.
func NewMapBackend() *MapBackend {
	backend := &MapBackend{}
	backend.storage = make(map[string]struct{})
	return backend
}

// Upsert adds elem to the set and reports whether it was absent before.
func (m *MapBackend) Upsert(elem string) bool {
	_, seen := m.storage[elem]
	if !seen {
		m.storage[elem] = struct{}{}
	}
	return !seen
}

// IterCallback calls callback once per stored element, in Go's unspecified map
// iteration order.
func (m *MapBackend) IterCallback(callback func(elem string)) {
	for elem := range m.storage {
		callback(elem)
	}
}

// Cleanup drops the stored set and asks the runtime to return freed memory to
// the operating system right away.
func (m *MapBackend) Cleanup() {
	m.storage = nil
	// The runtime normally keeps freed heap memory for reuse and returns it to
	// the OS gradually; debug.FreeOSMemory forces a garbage collection and
	// returns as much memory as possible immediately.
	debug.FreeOSMemory()
}
35 |
--------------------------------------------------------------------------------
/dns/dnsutil.go:
--------------------------------------------------------------------------------
1 | package dnsutil
2 |
3 | import (
4 | stringsutil "github.com/projectdiscovery/utils/strings"
5 | "github.com/weppos/publicsuffix-go/publicsuffix"
6 | )
7 |
8 | // Split takes a domain name and decomposes it into its subdomain and domain components.
9 | // The function returns the subdomain, the domain, and an error if the decomposition process fails.
10 | //
11 | // For example:
12 | // - Input: "http://www.example.com"
13 | // - Output: "www", "example.com", nil
14 | func Split(name string) (string, string, error) {
15 | name = stringsutil.TrimPrefixAny(name, "http://", "https://")
16 | dn, err := publicsuffix.ParseFromListWithOptions(publicsuffix.DefaultList, name, publicsuffix.DefaultFindOptions)
17 | if err != nil {
18 | return "", "", err
19 | }
20 |
21 | return dn.TRD, dn.SLD + "." + dn.TLD, nil
22 | }
23 |
--------------------------------------------------------------------------------
/dns/dnsutil_test.go:
--------------------------------------------------------------------------------
1 | package dnsutil
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestSplit(t *testing.T) {
8 | tests := []struct {
9 | name string
10 | subdomain string
11 | domain string
12 | expectErr bool
13 | }{
14 | {"www.example.com", "www", "example.com", false},
15 | {"http://www.example.com", "www", "example.com", false},
16 | {"example.com", "", "example.com", false},
17 | {"sub.sub.example.co.uk", "sub.sub", "example.co.uk", false},
18 | {"invalid_domain", "", "", true},
19 | {"", "", "", true},
20 | }
21 |
22 | for _, test := range tests {
23 | subdomain, domain, err := Split(test.name)
24 | if test.expectErr && err == nil {
25 | t.Errorf("expected error for domain %s, but got none", test.name)
26 | }
27 | if !test.expectErr && err != nil {
28 | t.Errorf("did not expect error for domain %s, but got %v", test.name, err)
29 | }
30 | if subdomain != test.subdomain {
31 | t.Errorf("expected subdomain %s for domain %s, but got %s", test.subdomain, test.name, subdomain)
32 | }
33 | if domain != test.domain {
34 | t.Errorf("expected domain %s for domain %s, but got %s", test.domain, test.name, domain)
35 | }
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/env/env.go:
--------------------------------------------------------------------------------
1 | package env
2 |
3 | import (
4 | "os"
5 | "strconv"
6 | "strings"
7 | "time"
8 | )
9 |
// TLS_VERIFY is true when the TLS_VERIFY environment variable equals "true"
// at program start-up.
var TLS_VERIFY = os.Getenv("TLS_VERIFY") == "true"

// DEBUG is true when the DEBUG environment variable equals "true" at program
// start-up.
var DEBUG = os.Getenv("DEBUG") == "true"
14 |
// ExpandWithEnv replaces each non-nil *string in variables with the value of
// the environment variable it names. One leading "$" is stripped before the
// lookup, so "$HOME" and "HOME" both resolve to the value of HOME. Variables
// that are not set become the empty string; nil pointers are skipped.
func ExpandWithEnv(variables ...*string) {
	for i := range variables {
		ptr := variables[i]
		if ptr != nil {
			name := strings.TrimPrefix(*ptr, "$")
			*ptr = os.Getenv(name)
		}
	}
}
25 |
// EnvType lists the types accepted by GetEnvOrDefault.
//
// NOTE: the ~ forms also admit named types (e.g. `type Mode string`), and
// ~rune admits rune/int32, but GetEnvOrDefault only parses the exact types
// string, int, bool, float64 and time.Duration. For any other type it always
// returns the default value.
type EnvType interface {
	~string | ~int | ~bool | ~float64 | time.Duration | ~rune
}

// GetEnvOrDefault returns the environment variable key converted to the type
// of defaultValue. defaultValue is returned when the variable is unset, empty,
// or cannot be parsed.
//
// Parsing uses strconv.Atoi (int), strconv.ParseBool (bool),
// strconv.ParseFloat with 64-bit precision (float64) and time.ParseDuration
// (time.Duration); strings are returned verbatim. Other types, including named
// types admitted by EnvType's ~ constraints and rune, are not parsed and
// always yield defaultValue.
func GetEnvOrDefault[T EnvType](key string, defaultValue T) T {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}

	// value is known to be non-empty from here on, so each branch only has to
	// check its parse error.
	switch any(defaultValue).(type) {
	case string:
		return any(value).(T)
	case int:
		intVal, err := strconv.Atoi(value)
		if err != nil {
			return defaultValue
		}
		return any(intVal).(T)
	case bool:
		boolVal, err := strconv.ParseBool(value)
		if err != nil {
			return defaultValue
		}
		return any(boolVal).(T)
	case float64:
		floatVal, err := strconv.ParseFloat(value, 64)
		if err != nil {
			return defaultValue
		}
		return any(floatVal).(T)
	case time.Duration:
		durationVal, err := time.ParseDuration(value)
		if err != nil {
			return defaultValue
		}
		return any(durationVal).(T)
	}
	return defaultValue
}
68 |
--------------------------------------------------------------------------------
/env/env_test.go:
--------------------------------------------------------------------------------
1 | package env
2 |
3 | import (
4 | "os"
5 | "testing"
6 | "time"
7 | )
8 |
9 | func TestExpandWithEnv(t *testing.T) {
10 | testEnvVar := "TEST_VAR"
11 | testEnvValue := "TestValue"
12 | os.Setenv(testEnvVar, testEnvValue)
13 | defer os.Unsetenv(testEnvVar)
14 |
15 | tests := []struct {
16 | input string
17 | expected string
18 | name string
19 | }{
20 | {"$" + testEnvVar, testEnvValue, "Existing env variable"},
21 | {"$NON_EXISTENT_VAR", "", "Non-existent env variable"},
22 | {"NOT_AN_ENV_VAR", "", "Not prefixed with $"},
23 | {"", "", "Empty string"},
24 | }
25 |
26 | for _, tt := range tests {
27 | t.Run(tt.name, func(t *testing.T) {
28 | ExpandWithEnv(&tt.input)
29 | if tt.input != tt.expected {
30 | t.Errorf("got %q, want %q", tt.input, tt.expected)
31 | }
32 | })
33 | }
34 | }
35 |
36 | func TestExpandWithEnvNilInput(t *testing.T) {
37 | defer func() {
38 | if r := recover(); r != nil {
39 | t.Errorf("The code panicked with %v", r)
40 | }
41 | }()
42 |
43 | var nilVar *string = nil
44 | ExpandWithEnv(nilVar)
45 | }
46 |
47 | func TestGetEnvOrDefault(t *testing.T) {
48 | // Test for string
49 | os.Setenv("TEST_STRING", "test")
50 | resultString := GetEnvOrDefault("TEST_STRING", "default")
51 | if resultString != "test" {
52 | t.Errorf("Expected 'test', got %s", resultString)
53 | }
54 |
55 | // Test for int
56 | os.Setenv("TEST_INT", "123")
57 | resultInt := GetEnvOrDefault("TEST_INT", 0)
58 | if resultInt != 123 {
59 | t.Errorf("Expected 123, got %d", resultInt)
60 | }
61 |
62 | // Test for bool
63 | os.Setenv("TEST_BOOL", "true")
64 | resultBool := GetEnvOrDefault("TEST_BOOL", false)
65 | if resultBool != true {
66 | t.Errorf("Expected true, got %t", resultBool)
67 | }
68 |
69 | // Test for float64
70 | os.Setenv("TEST_FLOAT", "1.23")
71 | resultFloat := GetEnvOrDefault("TEST_FLOAT", 0.0)
72 | if resultFloat != 1.23 {
73 | t.Errorf("Expected 1.23, got %f", resultFloat)
74 | }
75 |
76 | // Test for time.Duration
77 | os.Setenv("TEST_DURATION", "1h")
78 | resultDuration := GetEnvOrDefault("TEST_DURATION", time.Duration(0))
79 | if resultDuration != time.Hour {
80 | t.Errorf("Expected 1h, got %s", resultDuration)
81 | }
82 |
83 | // Test for default value
84 | resultDefault := GetEnvOrDefault("NON_EXISTING", "default")
85 | if resultDefault != "default" {
86 | t.Errorf("Expected 'default', got %s", resultDefault)
87 | }
88 | }
89 |
--------------------------------------------------------------------------------
/errkit/README.md:
--------------------------------------------------------------------------------
1 | # errkit
2 |
Why errkit when we already have errorutil?
4 |
5 | ----
6 |
7 | Introduced a year ago, `errorutil` aimed to capture error stacks for identifying deeply nested errors. However, its approach deviates from Go's error handling paradigm. In Go, libraries like "errors", "pkg/errors", and "uber.go/multierr" avoid using the `.Error()` method directly. Instead, they wrap errors with helper structs that implement specific interfaces, facilitating error chain traversal and the use of helper functions like `.Cause() error` or `.Unwrap() error` or `errors.Is()`. Contrarily, `errorutil` marshals errors to strings, which is incompatible with Go's error handling paradigm. Over time, the use of `errorutil` has become cumbersome due to its inability to replace any error package seamlessly and its lack of support for idiomatic error propagation or traversal in Go.
8 |
9 |
10 | `errkit` is a new error library that addresses the shortcomings of `errorutil`. It offers the following features:
11 |
12 | - Seamless replacement for existing error packages, requiring no syntax changes or refactoring:
13 | - `errors` package
14 | - `pkg/errors` package (now deprecated)
15 | - `uber/multierr` package
16 | - `errkit` is compatible with all known Go error handling implementations. It can parse errors from any library and works with existing error handling libraries and helper functions like `Is()`, `As()`, `Cause()`, and more.
17 | - `errkit` is Go idiomatic and adheres to the Go error handling paradigm.
18 | - `errkit` supports attributes for structured error information or logging using `slog.Attr` (optional).
19 | - `errkit` implements and categorizes errors into different kinds, as detailed below.
20 | - `ErrKindNetworkTemporary`
21 | - `ErrKindNetworkPermanent`
22 | - `ErrKindDeadline`
23 | - Custom kinds via `ErrKind` interface
24 | - `errkit` provides helper functions for structured error logging using `SlogAttrs` and `SlogAttrGroup`.
25 | - `errkit` offers helper functions to implement public or user-facing errors by using error kinds interface.
26 |
27 |
28 | **Attributes Support**
29 |
30 | `errkit` supports optional error wrapping with attributes `slog.Attr` for structured error logging, providing a more organized approach to error logging than string wrapping.
31 |
32 | ```go
// normal way of error propagating through nested stack
34 | err := errkit.New("i/o timeout")
35 |
36 | // xyz.go
37 | err := errkit.Wrap(err,"failed to connect %s",addr)
38 |
39 | // abc.go
err := errkit.Wrap(err,"error occurred when downloading %s",xyz)
41 | ```
42 |
43 | with attributes support you can do following
44 |
45 | ```go
// normal way of error propagating through nested stack
47 | err := errkit.New("i/o timeout")
48 |
49 | // xyz.go
50 | err = errkit.WithAttr(err,slog.Any("resource",domain))
51 |
52 | // abc.go
53 | err = errkit.WithAttr(err,slog.Any("action","download"))
54 | ```
55 |
56 | ## Note
57 |
58 | To keep errors concise and avoid unnecessary allocations, message wrapping and attributes count have a max depth set to 3. Adding more will not panic but will be simply ignored. This is configurable using the MAX_ERR_DEPTH env variable (default 3).
--------------------------------------------------------------------------------
/errkit/interfaces.go:
--------------------------------------------------------------------------------
1 | package errkit
2 |
3 | import "encoding/json"
4 |
5 | var (
6 | _ json.Marshaler = &ErrorX{}
7 | _ JoinedError = &ErrorX{}
8 | _ CauseError = &ErrorX{}
9 | _ ComparableError = &ErrorX{}
10 | _ error = &ErrorX{}
11 | )
12 |
// The interfaces below describe the error-inspection methods used by Go's
// errors package and by other error libraries. *ErrorX is checked at compile
// time against JoinedError, CauseError and ComparableError; WrappedError is
// declared here but is not part of that check.

// JoinedError is implemented by errors that combine several errors (as
// produced by Join).
type JoinedError interface {
	// Unwrap returns the underlying errors
	Unwrap() []error
}
21 |
// CauseError is implemented by errors that can report the error that
// originally caused them, following the Cause() convention of pkg/errors.
type CauseError interface {
	// Cause returns the original error that caused this one, without any
	// wrapping.
	Cause() error
}
27 |
// ComparableError is implemented by errors that provide their own matching
// logic; errors.Is consults this method while walking an error chain.
type ComparableError interface {
	// Is reports whether the current error matches (contains) err.
	Is(err error) bool
}
33 |
// WrappedError is implemented by errors that wrap exactly one other error
// (the single-error Unwrap form used by errors.Unwrap).
type WrappedError interface {
	// Unwrap returns the underlying error, or nil if there is none.
	Unwrap() error
}
39 |
--------------------------------------------------------------------------------
/errors/err_test.go:
--------------------------------------------------------------------------------
1 | package errorutil_test
2 |
3 | import (
4 | "fmt"
5 | "strings"
6 | "testing"
7 |
8 | errors "github.com/projectdiscovery/utils/errors"
9 | )
10 |
11 | func TestErrorEqual(t *testing.T) {
12 | err1 := fmt.Errorf("error init x")
13 | err2 := errors.NewWithErr(err1)
14 | err3 := errors.NewWithTag("testing", "error init")
15 | var errnil error
16 |
17 | if !errors.IsAny(err1, err2, errnil) {
18 | t.Errorf("expected errors to be equal")
19 | }
20 | if errors.IsAny(err1, err3, errnil) {
21 | t.Errorf("expected error to be not equal")
22 | }
23 | }
24 |
25 | func TestWrapWithNil(t *testing.T) {
26 | err1 := errors.NewWithTag("niltest", "non nil error").WithLevel(errors.Fatal)
27 | var errx error
28 |
29 | if errors.WrapwithNil(errx, err1) != nil {
30 | t.Errorf("when base error is nil ")
31 | }
32 | }
33 |
34 | func TestStackTrace(t *testing.T) {
35 | err := errors.New("base error")
36 | relay := func(err error) error {
37 | return err
38 | }
39 | errx := relay(err)
40 |
41 | t.Run("teststack", func(t *testing.T) {
42 | if strings.Contains(errx.Error(), "captureStack") {
43 | t.Errorf("stacktrace should be disabled by default")
44 | }
45 | errors.ShowStackTrace = true
46 | if !strings.Contains(errx.Error(), "captureStack") {
47 | t.Errorf("missing stacktrace got %v", errx.Error())
48 | }
49 | })
50 | }
51 |
52 | func TestErrorCallback(t *testing.T) {
53 | callbackExecuted := false
54 |
55 | err := errors.NewWithTag("callback", "got error").WithCallback(func(level errors.ErrorLevel, err string, tags ...string) {
56 | if level != errors.Runtime {
57 | t.Errorf("Default error level should be Runtime")
58 | }
59 | if tags[0] != "callback" {
60 | t.Errorf("missing callback")
61 | }
62 | callbackExecuted = true
63 | })
64 |
65 | errval := err.Error()
66 |
67 | if !strings.Contains(errval, "callback") || !strings.Contains(errval, "got error") || !strings.Contains(errval, "RUNTIME") {
68 | t.Errorf("error content missing expected values `callback,got error and Runtime` in error value but got %v", errval)
69 | }
70 |
71 | if !callbackExecuted {
72 | t.Errorf("error callback failed to execute")
73 | }
74 | }
75 |
--------------------------------------------------------------------------------
/errors/err_with_fmt.go:
--------------------------------------------------------------------------------
1 | package errorutil
2 |
3 | import (
4 | "fmt"
5 | )
6 |
// ErrWithFmt holds a reusable format string from which concrete errors are
// created with Msgf. It is a format holder, not an error itself.
type ErrWithFmt struct {
	fmt string
}

// Msgf returns a new error produced by formatting args with the stored format,
// exactly as fmt.Errorf(format, args...) would (so %w is supported).
//
// It has a value receiver so it can be called directly on the value returned
// by NewWithFmt, e.g. NewWithFmt("x: %s").Msgf("y"); pointers still work.
func (e ErrWithFmt) Msgf(args ...any) error {
	return fmt.Errorf(e.fmt, args...)
}

// Error always panics because ErrWithFmt is only a format holder. It has no
// string result, so ErrWithFmt does not satisfy the error interface.
func (e ErrWithFmt) Error() {
	panic("ErrWithFmt is a format holder")
}

// NewWithFmt returns an ErrWithFmt for the given format string.
// It panics if fmt is empty.
func NewWithFmt(fmt string) ErrWithFmt {
	if fmt == "" {
		panic("format can't be empty")
	}

	return ErrWithFmt{fmt: fmt}
}
28 |
--------------------------------------------------------------------------------
/errors/err_with_fmt_test.go:
--------------------------------------------------------------------------------
1 | package errorutil_test
2 |
3 | import (
4 | "testing"
5 |
6 | errors "github.com/projectdiscovery/utils/errors"
7 | "github.com/stretchr/testify/require"
8 | )
9 |
10 | func TestErrWithFmt(t *testing.T) {
11 | errBase := errors.NewWithFmt("error: %s")
12 | errWithMsg1 := errBase.Msgf("test1")
13 | errWithMsg2 := errBase.Msgf("test2")
14 |
15 | require.Equal(t, "error: test1", errWithMsg1.Error())
16 | require.Equal(t, "error: test2", errWithMsg2.Error())
17 | }
18 |
--------------------------------------------------------------------------------
/errors/errinterface.go:
--------------------------------------------------------------------------------
1 | package errorutil
2 |
// Error is an enriched version of the standard error: besides Error() string
// it supports tags, an ErrorLevel, wrapping of other errors, formatted
// messages and a callback (the original design also mentions stack traces).
// The builder-style methods return an Error so calls can be chained.
type Error interface {
	// WithTag assigns the given tag(s) to the error.
	WithTag(tag ...string) Error
	// WithLevel assigns the given ErrorLevel.
	WithLevel(level ErrorLevel) Error
	// Error implements the standard error interface.
	Error() string
	// Wrap wraps the given errors into this one (nil errors are skipped).
	Wrap(err ...error) Error
	// Msgf adds a message built from format and args.
	Msgf(format string, args ...any) Error
	// Equal compares this error with the given error(s).
	// NOTE(review): confirm whether one or all of them must match.
	Equal(err ...error) bool
	// WithCallback registers handle to be executed when the error is triggered
	// (for example when Error() is called).
	WithCallback(handle ErrCallback) Error
}
21 |
--------------------------------------------------------------------------------
/errors/errlevel.go:
--------------------------------------------------------------------------------
1 | package errorutil
2 |
// ErrorLevel is the severity attached to an enriched error.
type ErrorLevel uint

// Supported levels; Runtime is the default.
const (
	Panic ErrorLevel = iota
	Fatal
	Runtime // Default
)

// String returns the upper-case name of the level. Unknown values are
// reported as "RUNTIME", the default level.
func (l ErrorLevel) String() string {
	levelNames := [...]string{
		Panic:   "PANIC",
		Fatal:   "FATAL",
		Runtime: "RUNTIME",
	}
	if l < ErrorLevel(len(levelNames)) {
		return levelNames[l]
	}
	return "RUNTIME" // default is runtime
}
22 |
--------------------------------------------------------------------------------
/errors/errors.go:
--------------------------------------------------------------------------------
1 | package errorutil
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "net"
8 | "os"
9 | "strings"
10 | )
11 |
12 | // IsAny checks if err is not nil and matches any one of errxx errors
13 | // if match successful returns true else false
14 | // Note: no unwrapping is done here
15 | func IsAny(err error, errxx ...error) bool {
16 | if err == nil {
17 | return false
18 | }
19 | if enrichedErr, ok := err.(Error); ok {
20 | for _, v := range errxx {
21 | if enrichedErr.Equal(v) {
22 | return true
23 | }
24 | }
25 | } else {
26 | for _, v := range errxx {
27 | // check if v is an enriched error
28 | if ee, ok := v.(Error); ok && ee.Equal(err) {
29 | return true
30 | }
31 | // check standard error equality
32 | if strings.EqualFold(err.Error(), fmt.Sprint(v)) {
33 | return true
34 | }
35 | }
36 | }
37 | return false
38 | }
39 |
40 | // WrapfWithNil returns nil if error is nil but if err is not nil
41 | // wraps error with given msg unlike errors.Wrapf
42 | func WrapfWithNil(err error, format string, args ...any) Error {
43 | if err == nil {
44 | return nil
45 | }
46 | ee := NewWithErr(err)
47 | return ee.Msgf(format, args...)
48 | }
49 |
50 | // WrapwithNil returns nil if err is nil but wraps it with given
51 | // errors continuously if it is not nil
52 | func WrapwithNil(err error, errx ...error) Error {
53 | if err == nil {
54 | return nil
55 | }
56 | ee := NewWithErr(err)
57 | return ee.Wrap(errx...)
58 | }
59 |
// IsTimeout reports whether err represents a timeout. It matches when any
// error in err's chain implements net.Error with Timeout() == true, or when
// the chain contains context.DeadlineExceeded or os.ErrDeadlineExceeded.
// A nil err is never a timeout.
func IsTimeout(err error) bool {
	if err == nil {
		return false
	}
	// Named netErr so the local does not shadow the net package.
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return true
	}
	return errors.Is(err, context.DeadlineExceeded) || errors.Is(err, os.ErrDeadlineExceeded)
}
65 |
--------------------------------------------------------------------------------
/exec/README.md:
--------------------------------------------------------------------------------
1 | # executil
2 | The package contains various helpers to run and interact with external commands and binaries
--------------------------------------------------------------------------------
/exec/executil_test.go:
--------------------------------------------------------------------------------
1 | package executil
2 |
3 | import (
4 | "runtime"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/require"
8 | )
9 |
// newLineMarker is the line terminator the tests expect at the end of
// command output: CRLF on Windows, LF on every other platform.
var newLineMarker = "\n"

func init() {
	if runtime.GOOS == "windows" {
		newLineMarker = "\r\n"
	}
}
19 |
20 | func TestRun(t *testing.T) {
21 | // try to run the echo command
22 | s, err := Run("echo test")
23 | require.Nil(t, err, "failed execution", err)
24 | require.Equal(t, "test"+newLineMarker, s, "output doesn't contain expected result", s)
25 | }
26 |
27 | func TestRunAdv(t *testing.T) {
28 | testcases := []struct {
29 | GOOS string // OS
30 | Command string
31 | Expected string // expected output
32 | Contains string // expected output contains
33 | }{
34 | // Tests With Flags
35 | {"darwin", "uname -s", "Darwin", ""},
36 | {"linux", "uname -s", "Linux", ""},
37 | {"windows", "cmd /c ver", "", "Windows"},
38 | // Tests With CMD PIPE
39 | {"windows", `systeminfo | findstr /B /C:"OS Name"`, "", "Windows"},
40 | {"darwin", `sw_vers | grep -i "ProductName"`, "", "macOS"},
41 | {"linux", `uname -a | cut -d " " -f 1`, "Linux", ""},
42 | // Other Shell Specific Features
43 | {"windows", `cmd /c " echo This && echo Works"`, "This \r\nWorks", ""},
44 | {"linux", "true && echo This Works", "This Works", ""},
45 | {"darwin", "true && echo This Works", "This Works", ""},
46 | }
47 |
48 | runFunc := func(cmd string, expected string, contains string) {
49 | s, err := Run(cmd)
50 | require.Nilf(t, err, "%v failed to execute", cmd)
51 | if expected != "" {
52 | require.Equal(t, expected+newLineMarker, s)
53 | } else if contains != "" {
54 | require.Contains(t, s, contains)
55 | } else {
56 | t.Logf("Malformed test case : %v", cmd)
57 | }
58 | t.Logf("Test Successful: %v", cmd)
59 | }
60 |
61 | for _, v := range testcases {
62 | switch v.GOOS {
63 | case "windows":
64 | if runtime.GOOS != "windows" {
65 | continue
66 | }
67 | runFunc(v.Command, v.Expected, v.Contains)
68 | case "darwin":
69 | if runtime.GOOS != "darwin" {
70 | continue
71 | }
72 | runFunc(v.Command, v.Expected, v.Contains)
73 | case "linux":
74 | if runtime.GOOS != "linux" {
75 | continue
76 | }
77 | runFunc(v.Command, v.Expected, v.Contains)
78 | default:
79 | t.Logf("No Unit Test Available for this platform")
80 |
81 | }
82 | }
83 | }
84 |
85 | func TestRunSafe(t *testing.T) {
86 | _, err := RunSafe(`whoami | grep Hello`)
87 | require.Error(t, err)
88 | }
89 |
90 | func TestRunSh(t *testing.T) {
91 | if runtime.GOOS == "windows" {
92 | return
93 | }
94 | // try to run the echo command
95 | s, err := RunSh("echo", "test")
96 | require.Nil(t, err, "failed execution", err)
97 | require.Equal(t, "test"+newLineMarker, s, "output doesn't contain expected result", s)
98 | }
99 |
100 | func TestRunPS(t *testing.T) {
101 | if runtime.GOOS != "windows" {
102 | return
103 | }
104 | // run powershell command (runs in both ps1 and ps2)
105 | s, err := RunPS("get-host")
106 | require.Nil(t, err, "failed execution", err)
107 | require.Contains(t, s, "Microsoft.PowerShell", "failed to run powershell command get-host")
108 | }
109 |
--------------------------------------------------------------------------------
/file/README.md:
--------------------------------------------------------------------------------
1 | # fileutil
2 | The package contains various helpers to interact with files
--------------------------------------------------------------------------------
/file/clean_test.go:
--------------------------------------------------------------------------------
1 | package fileutil
2 |
3 | import (
4 | "io"
5 | "os"
6 | "strings"
7 | "testing"
8 | )
9 |
10 | func FuzzSafeOpen(f *testing.F) {
11 |
12 | // ==========setup==========
13 |
14 | bin, err := os.ReadFile("tests/path-traversal.txt")
15 | if err != nil {
16 | f.Fatalf("failed to read file: %s", err)
17 | }
18 |
19 | fuzzPayloads := strings.Split(string(bin), "\n")
20 |
21 | file, err := os.CreateTemp("", "*")
22 | if err != nil {
23 | f.Fatal(err)
24 | }
25 | _, _ = file.WriteString("pwned!")
26 | _ = file.Close()
27 |
28 | defer func(tmp string) {
29 | if err = os.Remove(tmp); err != nil {
30 | panic(err)
31 | }
32 | }(file.Name())
33 |
34 | // ==========fuzzing==========
35 |
36 | for _, payload := range fuzzPayloads {
37 | f.Add(strings.ReplaceAll(payload, "{FILE}", f.Name()), f.Name())
38 |
39 | }
40 | f.Fuzz(func(t *testing.T, fuzzPath string, targetPath string) {
41 | cleaned, err := CleanPath(fuzzPath)
42 | if err != nil {
43 | // Ignore errors
44 | return
45 | }
46 | if cleaned != targetPath {
47 | // cleaned path is different from target file
48 | // so verify if 'path' is actually valid and not random chars
49 | result, err := SafeOpen(cleaned)
50 | if err != nil {
51 | // Ignore errors
52 | return
53 | }
54 | defer result.Close()
55 | bin, _ := io.ReadAll(result)
56 | if string(bin) == "pwned!" {
57 | t.Fatalf("pwned! cleaned=%s ,input=%s", cleaned, fuzzPath)
58 | }
59 | }
60 |
61 | })
62 | }
63 |
--------------------------------------------------------------------------------
/file/tests/empty_lines.txt:
--------------------------------------------------------------------------------
1 | test
2 | test1
3 |
4 |
5 |
6 |
7 | test2
8 |
9 |
10 | test3
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | test4
19 |
--------------------------------------------------------------------------------
/file/tests/path-traversal.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/projectdiscovery/utils/a1fc48cb9b6cc2049ff14863224320e47eb2f936/file/tests/path-traversal.txt
--------------------------------------------------------------------------------
/file/tests/pipe_separator.txt:
--------------------------------------------------------------------------------
1 | test|test1|test2|test3|test4
2 |
--------------------------------------------------------------------------------
/file/tests/standard.txt:
--------------------------------------------------------------------------------
1 | test
2 | test1
3 | test2
4 | test3
5 | test4
6 |
--------------------------------------------------------------------------------
/folder/README.md:
--------------------------------------------------------------------------------
1 | # folderutil
2 | The package contains various helpers to interact with folders
3 |
4 | ## UserConfigDirOrDefault
5 |
6 | UserConfigDirOrDefault returns the default root directory to use for user-specific configuration data. Users should create their own application-specific subdirectory within this one and use that.
7 |
8 | On Unix systems, it returns $XDG_CONFIG_HOME as specified by https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html if non-empty, else $HOME/.config. On Darwin, it returns $HOME/Library/Application Support. On Windows, it returns %AppData%. On Plan 9, it returns $home/lib.
9 |
10 | If the location cannot be determined (for example, $HOME is not defined), it returns the given default value instead.
--------------------------------------------------------------------------------
/folder/folderutil_linux_test.go:
--------------------------------------------------------------------------------
1 | //go:build !windows
2 |
3 | package folderutil
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestPathInfo(t *testing.T) {
12 | got, err := NewPathInfo("/a/b/c")
13 | assert.Nil(t, err)
14 | gotPaths, err := got.Paths()
15 | assert.Nil(t, err)
16 | assert.EqualValues(t, []string{"/", "/a", "/a/b", "/a/b/c"}, gotPaths)
17 | gotMeshPaths, err := got.MeshWith("test.txt")
18 | assert.Nil(t, err)
19 | assert.EqualValues(t, []string{"/test.txt", "/a/test.txt", "/a/b/test.txt", "/a/b/c/test.txt"}, gotMeshPaths)
20 | }
21 |
--------------------------------------------------------------------------------
/folder/folderutil_win_test.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package folderutil
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestPathInfo(t *testing.T) {
12 | got, err := NewPathInfo("c:\\a\\b\\c")
13 | assert.Nil(t, err)
14 | gotPaths, err := got.Paths()
15 | assert.Nil(t, err)
16 | assert.EqualValues(t, []string{".", "c:\\", "c:\\a", "c:\\a\\b", "c:\\a\\b\\c"}, gotPaths)
17 | gotMeshPaths, err := got.MeshWith("test.txt")
18 | assert.Nil(t, err)
19 | assert.EqualValues(t, []string{"test.txt", "c:\\test.txt", "c:\\a\\test.txt", "c:\\a\\b\\test.txt", "c:\\a\\b\\c\\test.txt"}, gotMeshPaths)
20 | }
21 |
--------------------------------------------------------------------------------
/folder/std_dirs.go:
--------------------------------------------------------------------------------
1 | package folderutil
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "os/user"
7 | "path/filepath"
8 | )
9 |
10 | // The helpers below resolve the standard directories
11 | // that tools should use to store their configuration
12 | // and cache data
13 |
14 | // HomeDirOrDefault tries to obtain the user's home directory and
15 | // returns the default if it cannot be obtained.
16 | func HomeDirOrDefault(defaultDirectory string) string {
17 | if homeDir, err := os.UserHomeDir(); err == nil && IsWritable(homeDir) {
18 | return homeDir
19 | }
20 | if user, err := user.Current(); err == nil && IsWritable(user.HomeDir) {
21 | return user.HomeDir
22 | }
23 | return defaultDirectory
24 | }
25 |
// UserConfigDirOrDefault returns the result of os.UserConfigDir, or
// defaultConfigDir when that call fails.
func UserConfigDirOrDefault(defaultConfigDir string) string {
	if dir, err := os.UserConfigDir(); err == nil {
		return dir
	}
	return defaultConfigDir
}
34 |
35 | // AppConfigDirOrDefault returns the app config directory
36 | func AppConfigDirOrDefault(defaultAppConfigDir string, toolName string) string {
37 | userConfigDir := UserConfigDirOrDefault("")
38 | if userConfigDir == "" {
39 | return filepath.Join(defaultAppConfigDir, toolName)
40 | }
41 | return filepath.Join(userConfigDir, toolName)
42 | }
43 |
// AppCacheDirOrDefault returns toolName's cache directory. It is a toolName
// subdirectory of os.UserCacheDir, or of defaultCacheDir when that call fails
// or returns an empty path.
func AppCacheDirOrDefault(defaultCacheDir string, toolName string) string {
	base := defaultCacheDir
	if dir, err := os.UserCacheDir(); err == nil && dir != "" {
		base = dir
	}
	return filepath.Join(base, toolName)
}
52 |
53 | // Prints the standard directories for a tool
54 | func PrintStdDirs(toolName string) {
55 | appConfigDir := AppConfigDirOrDefault("", toolName)
56 | appCacheDir := AppCacheDirOrDefault("", toolName)
57 | fmt.Printf("[+] %v %-13v: %v\n", toolName, "AppConfigDir", appConfigDir)
58 | fmt.Printf("[+] %v %-13v: %v\n", toolName, "AppCacheDir", appCacheDir)
59 | }
60 |
--------------------------------------------------------------------------------
/generic/generic.go:
--------------------------------------------------------------------------------
1 | package generic
2 |
3 | import (
4 | "bytes"
5 | "encoding/gob"
6 | )
7 |
// EqualsAny reports whether base is equal to at least one value in all.
// It returns false when all is empty.
func EqualsAny[T comparable](base T, all ...T) bool {
	for i := range all {
		if all[i] == base {
			return true
		}
	}
	return false
}
18 |
// EqualsAll reports whether every value in all is equal to base. An empty
// all is deliberately reported as false, not vacuously true.
func EqualsAll[T comparable](base T, all ...T) bool {
	matched := 0
	for _, v := range all {
		if v != base {
			return false
		}
		matched++
	}
	return matched > 0
}
32 |
// ApproxSizeOf returns the approximate size of v in bytes. The size is
// measured as the length of v's encoding/gob serialization, gob framing
// included, so it estimates the encoded size rather than the exact in-memory
// footprint. An error is returned when gob cannot encode v (for example a
// channel or function value).
func ApproxSizeOf[T any](v T) (int, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(v); err != nil {
		return 0, err
	}
	return buf.Len(), nil
}
41 |
--------------------------------------------------------------------------------
/generic/generic_test.go:
--------------------------------------------------------------------------------
1 | package generic
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/require"
7 | )
8 |
9 | func TestEqualsAnyInt(t *testing.T) {
10 | testCases := []struct {
11 | Base int
12 | All []int
13 | Expected bool
14 | }{
15 | {3, []int{1, 2, 3, 4}, true},
16 | {5, []int{1, 2, 3, 4}, false},
17 | {0, []int{0}, true},
18 | {0, []int{1}, false},
19 | }
20 |
21 | for _, tc := range testCases {
22 | actual := EqualsAny(tc.Base, tc.All...)
23 | require.Equal(t, tc.Expected, actual)
24 | }
25 | }
26 |
27 | func TestEqualsAnyString(t *testing.T) {
28 | testCases := []struct {
29 | Base string
30 | All []string
31 | Expected bool
32 | }{
33 | {"test", []string{"test1", "test", "test2", "test3"}, true},
34 | {"test", []string{"test1", "test2", "test3", "test4"}, false},
35 | {"", []string{""}, true},
36 | {"", []string{"not empty"}, false},
37 | }
38 |
39 | for _, tc := range testCases {
40 | actual := EqualsAny(tc.Base, tc.All...)
41 | require.Equal(t, tc.Expected, actual)
42 | }
43 | }
44 |
45 | func TestEqualsAllInt(t *testing.T) {
46 | testCases := []struct {
47 | Base int
48 | All []int
49 | Expected bool
50 | }{
51 | {5, []int{5, 5, 5, 5}, true},
52 | {5, []int{1, 2, 3, 4}, false},
53 | {0, []int{}, false},
54 | }
55 |
56 | for _, tc := range testCases {
57 | actual := EqualsAll(tc.Base, tc.All...)
58 | require.Equal(t, tc.Expected, actual)
59 | }
60 | }
61 |
62 | func TestEqualsAllString(t *testing.T) {
63 | testCases := []struct {
64 | Base string
65 | All []string
66 | Expected bool
67 | }{
68 | {"test", []string{"test", "test", "test", "test"}, true},
69 | {"test", []string{"test", "test1", "test2", "test3"}, false},
70 | {"", []string{}, false},
71 | }
72 |
73 | for _, tc := range testCases {
74 | actual := EqualsAll(tc.Base, tc.All...)
75 | require.Equal(t, tc.Expected, actual)
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/generic/lockable.go:
--------------------------------------------------------------------------------
1 | package generic
2 |
3 | import (
4 | "sync"
5 | )
6 |
// Lockable pairs a value V with an embedded sync.RWMutex. The mutex methods
// (Lock, RLock, ...) are promoted, so callers may also lock it directly.
type Lockable[K any] struct {
	V K
	sync.RWMutex
}

// Do holds the write lock while calling f with the current V. f receives V
// by value; to store a new value, f assigns to the Lockable's V field itself.
// f must not lock the same Lockable again, because the lock is already held.
func (v *Lockable[K]) Do(f func(val K)) {
	v.RWMutex.Lock()
	defer v.RWMutex.Unlock()
	current := v.V
	f(current)
}

// WithLock returns a new Lockable holding val.
func WithLock[K any](val K) *Lockable[K] {
	l := new(Lockable[K])
	l.V = val
	return l
}
21 |
--------------------------------------------------------------------------------
/generic/lockable_test.go:
--------------------------------------------------------------------------------
1 | package generic
2 |
3 | import (
4 | "sync"
5 | "testing"
6 | )
7 |
8 | func TestDo(t *testing.T) {
9 | val := 10
10 | l := WithLock(val)
11 | l.Do(func(v int) {
12 | if v != val {
13 | t.Errorf("Expected %d, got %d", val, v)
14 | }
15 | })
16 | }
17 |
18 | func TestLockableConcurrency(t *testing.T) {
19 | l := WithLock(0)
20 |
21 | var wg sync.WaitGroup
22 |
23 | for i := 0; i < 100; i++ {
24 | wg.Add(1)
25 | go func() {
26 | defer wg.Done()
27 | for j := 0; j < 1000; j++ {
28 | l.Do(func(v int) {
29 | v++
30 | l.V = v
31 | })
32 | }
33 | }()
34 | }
35 |
36 | wg.Wait()
37 |
38 | if l.V != 100*1000 {
39 | t.Errorf("Expected counter to be %d, but got %d", 100*1000, l.V)
40 | }
41 | }
42 |
43 | func TestLockableStringManipulation(t *testing.T) {
44 | str := "initial"
45 | l := WithLock(str)
46 |
47 | l.Do(func(s string) {
48 | s += " - updated"
49 | l.V = s
50 | })
51 |
52 | if l.V != "initial - updated" {
53 | t.Errorf("Expected 'initial - updated', got '%s'", str)
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/global/max_threads.go:
--------------------------------------------------------------------------------
1 | package global
2 |
3 | import (
4 | "os"
5 | "strconv"
6 |
7 | "github.com/projectdiscovery/utils/sysutil"
8 | )
9 |
10 | const OS_MAX_THREADS_ENV = "OS_MAX_THREADS"
11 |
12 | func init() {
13 | handleOSMaxThreads()
14 | }
15 |
16 | func handleOSMaxThreads() {
17 | osMaxThreads := os.Getenv(OS_MAX_THREADS_ENV)
18 | if osMaxThreads == "" {
19 | return
20 | }
21 | if value, err := strconv.Atoi(osMaxThreads); err == nil && value > 0 {
22 | _ = sysutil.SetMaxThreads(value)
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/healthcheck/connection.go:
--------------------------------------------------------------------------------
1 | package healthcheck
2 |
3 | import (
4 | "fmt"
5 | "net"
6 | "strconv"
7 | "time"
8 | )
9 |
// ConnectionInfo is the result of CheckConnection.
type ConnectionInfo struct {
	// Host is the host that was dialed.
	Host string
	// Successful is true when the dial returned no error.
	Successful bool
	// Message is a human-readable summary built by CheckConnection.
	Message string
	// Error is the dial error, or nil on success.
	Error error
}
16 |
17 | func CheckConnection(host string, port int, protocol string, timeout time.Duration) ConnectionInfo {
18 | address := net.JoinHostPort(host, strconv.Itoa(port))
19 | conn, err := net.DialTimeout(protocol, address, timeout)
20 | if conn != nil {
21 | conn.Close()
22 | }
23 |
24 | return ConnectionInfo{
25 | Host: host,
26 | Successful: err == nil,
27 | Message: fmt.Sprintf("%s Connect (%s:%v): %s", protocol, host, port, "Successful"),
28 | Error: err,
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/healthcheck/connection_test.go:
--------------------------------------------------------------------------------
1 | package healthcheck
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestCheckConnection(t *testing.T) {
11 | t.Run("Test successful connection", func(t *testing.T) {
12 | info := CheckConnection("scanme.sh", 80, "tcp", 1*time.Second)
13 | assert.NoError(t, info.Error)
14 | assert.True(t, info.Successful)
15 | assert.Equal(t, "scanme.sh", info.Host)
16 | assert.Contains(t, info.Message, "Successful")
17 | })
18 |
19 | t.Run("Test unsuccessful connection", func(t *testing.T) {
20 | info := CheckConnection("invalid.website", 80, "tcp", 1*time.Second)
21 | assert.Error(t, info.Error)
22 | })
23 |
24 | t.Run("Test timeout connection", func(t *testing.T) {
25 | info := CheckConnection("192.0.2.0", 80, "tcp", 1*time.Millisecond)
26 | assert.Error(t, info.Error)
27 | })
28 | }
29 |
--------------------------------------------------------------------------------
/healthcheck/dns.go:
--------------------------------------------------------------------------------
1 | package healthcheck
2 |
3 | import (
4 | "context"
5 | "net"
6 | "strings"
7 | )
8 |
// DnsResolveInfo is the result of DnsResolve.
type DnsResolveInfo struct {
	Host        string       // name that was looked up
	Resolver    string       // DNS server exactly as passed by the caller
	Successful  bool         // true when the lookup returned no error
	IPAddresses []net.IPAddr // addresses returned by the lookup; nil on failure
	Error       error        // lookup error, or nil on success
}
16 |
17 | func DnsResolve(host string, resolver string) DnsResolveInfo {
18 | ipAddresses, err := getIPAddresses(host, resolver)
19 |
20 | return DnsResolveInfo{
21 | Host: host,
22 | Resolver: resolver,
23 | Successful: err == nil,
24 | IPAddresses: ipAddresses,
25 | Error: err,
26 | }
27 | }
28 |
// getIPAddresses resolves name to its IP addresses using Go's built-in
// resolver, sending every query to dnsServer only. dnsServer may be
// "host:port" or a bare host/IP (IPv4 or IPv6, with or without brackets);
// in the bare case port 53 is used.
func getIPAddresses(name, dnsServer string) ([]net.IPAddr, error) {
	server := withDefaultDNSPort(dnsServer)

	resolver := net.Resolver{
		PreferGo: true,
		// Ignore the address chosen by the resolver and always dial server.
		Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, network, server)
		},
	}

	resolvedIPs, err := resolver.LookupIPAddr(context.Background(), name)
	if err != nil {
		return nil, err
	}
	return resolvedIPs, nil
}

// withDefaultDNSPort returns server unchanged when it already has a port, and
// otherwise appends port 53. Unlike a "contains ':'" test, this correctly
// handles bare IPv6 literals such as "2606:4700::1111" or "[::1]".
func withDefaultDNSPort(server string) string {
	if _, _, err := net.SplitHostPort(server); err == nil {
		return server
	}
	host := strings.TrimSuffix(strings.TrimPrefix(server, "["), "]")
	return net.JoinHostPort(host, "53")
}
47 |
--------------------------------------------------------------------------------
/healthcheck/dns_test.go:
--------------------------------------------------------------------------------
1 | package healthcheck
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestDnsResolve(t *testing.T) {
10 | t.Run("Successful resolution", func(t *testing.T) {
11 | info := DnsResolve("scanme.sh", "1.1.1.1")
12 | assert.NoError(t, info.Error)
13 | assert.True(t, info.Successful)
14 | assert.Equal(t, "scanme.sh", info.Host)
15 | assert.Equal(t, "1.1.1.1", info.Resolver)
16 | assert.NotEmpty(t, info.IPAddresses)
17 | })
18 |
19 | t.Run("Unsuccessful resolution due to invalid host", func(t *testing.T) {
20 | info := DnsResolve("invalid.website", "1.1.1.1")
21 | assert.Error(t, info.Error)
22 | })
23 |
24 | t.Run("Unsuccessful resolution due to invalid resolver", func(t *testing.T) {
25 | info := DnsResolve("google.com", "invalid.resolver")
26 | assert.Error(t, info.Error)
27 | })
28 | }
29 |
--------------------------------------------------------------------------------
/healthcheck/environment.go:
--------------------------------------------------------------------------------
1 | package healthcheck
2 |
3 | import (
4 | "os"
5 | "runtime"
6 |
7 | "github.com/projectdiscovery/fdmax"
8 | iputil "github.com/projectdiscovery/utils/ip"
9 | permissionutil "github.com/projectdiscovery/utils/permission"
10 | router "github.com/projectdiscovery/utils/routing"
11 | )
12 |
// EnvironmentInfo is a snapshot of the host environment, filled in by
// CollectEnvironmentInfo.
type EnvironmentInfo struct {
	ExternalIPv4   string // value returned by iputil.WhatsMyIP
	Admin          bool   // permissionutil.IsRoot
	Arch           string // runtime.GOARCH
	Compiler       string // runtime.Compiler
	GoVersion      string // runtime.Version()
	OSName         string // runtime.GOOS
	ProgramVersion string // version string supplied by the caller
	OutboundIPv4   string // String() of the IPv4 returned by router.GetOutboundIPs
	OutboundIPv6   string // String() of the IPv6 returned by router.GetOutboundIPs
	Ulimit         Ulimit // limits reported by fdmax.Get; zero when unavailable
	PathEnvVar     string // value of the PATH environment variable
	Error          error  // NOTE(review): never populated by CollectEnvironmentInfo
}

// Ulimit holds the current and maximum limits reported by fdmax.Get
// (presumably the open-file-descriptor limits; confirm against fdmax).
type Ulimit struct {
	Current uint64
	Max     uint64
}
32 |
33 | func CollectEnvironmentInfo(appVersion string) EnvironmentInfo {
34 | externalIPv4, _ := iputil.WhatsMyIP()
35 | outboundIPv4, outboundIPv6, _ := router.GetOutboundIPs()
36 |
37 | ulimit := Ulimit{}
38 | limit, err := fdmax.Get()
39 | if err == nil {
40 | ulimit.Current = limit.Current
41 | ulimit.Max = limit.Max
42 | }
43 |
44 | return EnvironmentInfo{
45 | ExternalIPv4: externalIPv4,
46 | Admin: permissionutil.IsRoot,
47 | Arch: runtime.GOARCH,
48 | Compiler: runtime.Compiler,
49 | GoVersion: runtime.Version(),
50 | OSName: runtime.GOOS,
51 | ProgramVersion: appVersion,
52 | OutboundIPv4: outboundIPv4.String(),
53 | OutboundIPv6: outboundIPv6.String(),
54 | Ulimit: ulimit,
55 | PathEnvVar: os.Getenv("PATH"),
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/healthcheck/environment_test.go:
--------------------------------------------------------------------------------
1 | package healthcheck
2 |
3 | import (
4 | "os"
5 | "runtime"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestCollectEnvironmentInfo(t *testing.T) {
12 | t.Run("Collect Environment Info", func(t *testing.T) {
13 | programVersion := "1.0.0"
14 |
15 | environmentInfo := CollectEnvironmentInfo(programVersion)
16 | assert.NoError(t, environmentInfo.Error, "Error should not have occurred when collecting environment info")
17 | assert.NotNil(t, environmentInfo, "EnvironmentInfo should not be nil")
18 | assert.Equal(t, programVersion, environmentInfo.ProgramVersion, "Program version should match program version")
19 | assert.Equal(t, runtime.GOARCH, environmentInfo.Arch, "Architecture should match runtime")
20 | assert.Equal(t, runtime.Compiler, environmentInfo.Compiler, "Compiler should match runtime")
21 | assert.Equal(t, runtime.Version(), environmentInfo.GoVersion, "Go version should match runtime")
22 | assert.Equal(t, runtime.GOOS, environmentInfo.OSName, "OS name should match runtime")
23 | assert.Equal(t, os.Getenv("PATH"), environmentInfo.PathEnvVar, "PATH environment variable should match system PATH")
24 | })
25 | }
26 |
--------------------------------------------------------------------------------
/healthcheck/healthcheck.go:
--------------------------------------------------------------------------------
1 | package healthcheck
2 |
3 | import (
4 | "path/filepath"
5 |
6 | fileutil "github.com/projectdiscovery/utils/file"
7 | folderutil "github.com/projectdiscovery/utils/folder"
8 | )
9 |
var (
	// DefaultPathsToCheckPermission holds one path:
	// <HomeDirOrDefault("")>/.config/<fileutil.ExecutableName()>.
	// It is computed once at package initialization. When no writable home
	// directory is found, HomeDirOrDefault returns "" and the path is
	// relative.
	DefaultPathsToCheckPermission = []string{filepath.Join(folderutil.HomeDirOrDefault(""), ".config", fileutil.ExecutableName())}
	// DefaultHostsToCheckConnectivity lists the hosts Do resolves via DNS.
	DefaultHostsToCheckConnectivity = []string{"scanme.sh"}
	// DefaultResolver is the DNS server (host:port) used by DefaultOptions.
	DefaultResolver = "1.1.1.1:53"
)

// HealthCheckInfo aggregates the results gathered by Do.
type HealthCheckInfo struct {
	EnvironmentInfo EnvironmentInfo
	PathPermissions []PathPermission // one entry per Options.Paths element
	DnsResolveInfos []DnsResolveInfo // one entry per Options.Hosts element
}

// Options selects what Do checks.
type Options struct {
	Paths    []string // paths passed to CheckPathPermission
	Hosts    []string // hosts resolved with DnsResolve
	Resolver string   // DNS server used for every host
}

// DefaultOptions is used by Do when it is called with nil options. Its fields
// are copied from the defaults above at package initialization, so
// reassigning those variables later does not change DefaultOptions.
var DefaultOptions = Options{
	Paths:    DefaultPathsToCheckPermission,
	Hosts:    DefaultHostsToCheckConnectivity,
	Resolver: DefaultResolver,
}
33 |
34 | func Do(programVersion string, options *Options) (healthCheckInfo HealthCheckInfo) {
35 | if options == nil {
36 | options = &DefaultOptions
37 | }
38 | healthCheckInfo.EnvironmentInfo = CollectEnvironmentInfo(programVersion)
39 | for _, path := range options.Paths {
40 | healthCheckInfo.PathPermissions = append(healthCheckInfo.PathPermissions, CheckPathPermission(path))
41 | }
42 | for _, host := range options.Hosts {
43 | healthCheckInfo.DnsResolveInfos = append(healthCheckInfo.DnsResolveInfos, DnsResolve(host, options.Resolver))
44 | }
45 | return
46 | }
47 |
--------------------------------------------------------------------------------
/healthcheck/path_permission.go:
--------------------------------------------------------------------------------
1 | package healthcheck
2 |
3 | import (
4 | "errors"
5 |
6 | fileutil "github.com/projectdiscovery/utils/file"
7 | )
8 |
// PathPermission is the result of CheckPathPermission.
// NOTE(review): path, isReadable and isWritable are unexported, so code
// outside this package (and encoding/json) cannot read them; consider
// exported accessors.
type PathPermission struct {
	path       string // path that was checked
	isReadable bool   // result of fileutil.IsReadable; false when the path is missing
	isWritable bool   // result of fileutil.IsWriteable; false when the path is missing
	Error      error  // set when fileutil.FileExists reports the path missing
}
15 |
16 | // CheckPathPermission checks the permissions of the given file or directory.
17 | func CheckPathPermission(path string) (pathPermission PathPermission) {
18 | pathPermission.path = path
19 | if !fileutil.FileExists(path) {
20 | pathPermission.Error = errors.New("file or directory doesn't exist at " + path)
21 | return
22 | }
23 |
24 | pathPermission.isReadable, _ = fileutil.IsReadable(path)
25 | pathPermission.isWritable, _ = fileutil.IsWriteable(path)
26 |
27 | return
28 | }
29 |
--------------------------------------------------------------------------------
/healthcheck/path_permission_test.go:
--------------------------------------------------------------------------------
1 | package healthcheck
2 |
3 | import (
4 | "os"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestCheckPathPermission(t *testing.T) {
11 | t.Run("file with read and write permissions", func(t *testing.T) {
12 | filename := "testfile_read_write.txt"
13 | _, err := os.Create(filename)
14 | defer os.Remove(filename)
15 | assert.NoError(t, err)
16 |
17 | permission := CheckPathPermission(filename)
18 | assert.NoError(t, permission.Error)
19 | assert.Equal(t, true, permission.isReadable)
20 | assert.Equal(t, true, permission.isWritable)
21 | })
22 |
23 | t.Run("non-existing file", func(t *testing.T) {
24 | filename := "non_existing_file.txt"
25 | permission := CheckPathPermission(filename)
26 | assert.Error(t, permission.Error)
27 | })
28 |
29 | t.Run("file without write permission", func(t *testing.T) {
30 | filename := "testfile_read_only.txt"
31 | file, err := os.Create(filename)
32 | assert.NoError(t, err)
33 |
34 | err = file.Chmod(0444) // read-only permissions
35 | assert.NoError(t, err)
36 |
37 | defer os.Remove(filename)
38 | permission := CheckPathPermission(filename)
39 |
40 | assert.NoError(t, permission.Error)
41 | assert.Equal(t, true, permission.isReadable)
42 | assert.Equal(t, false, permission.isWritable)
43 | })
44 | }
45 |
--------------------------------------------------------------------------------
/http/README.md:
--------------------------------------------------------------------------------
1 | # httputil
2 | The package contains various helpers related to http protocol
--------------------------------------------------------------------------------
/http/chain.go:
--------------------------------------------------------------------------------
1 | package httputil
2 |
3 | import (
4 | "net/http"
5 | "net/http/httputil"
6 | )
7 |
// ChainItem describes one request=>response hop of a redirect chain, as
// produced by GetChain.
//
// Deprecated: use ResponseChain instead which is more efficient and lazy
type ChainItem struct {
	Request    []byte // request dump without body (httputil.DumpRequest)
	Response   []byte // response dump without body (httputil.DumpResponse)
	StatusCode int    // status code of the response
	Location   string // resolved Location header, or "" when absent/unparsable
	RequestURL string // URL of the request, as a string
}
17 |
18 | // GetChain if redirects
19 | // Deprecated: use ResponseChain instead which is more efficient and lazy
20 | func GetChain(r *http.Response) (chain []ChainItem, err error) {
21 | lastresp := r
22 | for lastresp != nil {
23 | lastreq := lastresp.Request
24 | lastreqDump, err := httputil.DumpRequest(lastreq, false)
25 | if err != nil {
26 | return nil, err
27 | }
28 | lastrespDump, err := httputil.DumpResponse(lastresp, false)
29 | if err != nil {
30 | return nil, err
31 | }
32 | var location string
33 | if l, err := lastresp.Location(); err == nil {
34 | location = l.String()
35 | }
36 | requestURL := lastreq.URL.String()
37 | chain = append(chain, ChainItem{Request: lastreqDump, Response: lastrespDump, StatusCode: lastresp.StatusCode, Location: location, RequestURL: requestURL})
38 | // process next
39 | lastresp = lastreq.Response
40 | }
41 | // reverse the slice in order to have the chain in progressive order
42 | for i, j := 0, len(chain)-1; i < j; i, j = i+1, j-1 {
43 | chain[i], chain[j] = chain[j], chain[i]
44 | }
45 | return
46 | }
47 |
--------------------------------------------------------------------------------
/http/httputil.go:
--------------------------------------------------------------------------------
1 | package httputil
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io"
7 | "net/http"
8 | "net/http/httputil"
9 | )
10 |
// AllHTTPMethods returns a new slice on every call containing the nine method
// constants from net/http, in this order: GET, HEAD, POST, PUT, PATCH, DELETE,
// CONNECT, OPTIONS, TRACE.
func AllHTTPMethods() []string {
	methods := make([]string, 0, 9)
	methods = append(methods,
		http.MethodGet,
		http.MethodHead,
		http.MethodPost,
		http.MethodPut,
		http.MethodPatch,
		http.MethodDelete,
		http.MethodConnect,
		http.MethodOptions,
		http.MethodTrace,
	)
	return methods
}
25 |
// DumpRequest renders req, body included, as a client would send it, using
// httputil.DumpRequestOut. The string form of whatever DumpRequestOut
// produced is returned together with its error.
func DumpRequest(req *http.Request) (string, error) {
	var (
		dump []byte
		err  error
	)
	if dump, err = httputil.DumpRequestOut(req, true); err != nil {
		return string(dump), err
	}
	return string(dump), nil
}
32 |
// DumpResponseHeadersAndRaw returns the response head (status line and
// headers) and the full raw response including the body.
//
// Informational responses (1xx, e.g. a websocket upgrade) cannot be dumped by
// httputil.DumpResponse. For them both results are the same hand-built text:
// resp.Status, then one "Name: value" line per header value, each ending in
// "\n". Header order follows map iteration and is not deterministic. The two
// returned slices are independent copies.
//
// For every other response the body is read in full and the original body is
// closed. resp.Body is then replaced with an in-memory copy, so the caller can
// read it again. If reading fails after some data arrived, the partial data is
// still dumped; if nothing arrived, the read error is returned.
func DumpResponseHeadersAndRaw(resp *http.Response) (headers, fullresp []byte, err error) {
	// httputil.DumpResponse does not work with websockets
	if resp.StatusCode >= http.StatusContinue && resp.StatusCode <= http.StatusEarlyHints {
		var raw bytes.Buffer
		raw.WriteString(resp.Status + "\n")
		for name, values := range resp.Header {
			// One line per value; formatting the []string directly would
			// print "Name: [value]".
			for _, value := range values {
				fmt.Fprintf(&raw, "%s: %s\n", name, value)
			}
		}
		s := raw.String()
		return []byte(s), []byte(s), nil
	}
	headers, err = httputil.DumpResponse(resp, false)
	if err != nil {
		return
	}
	// logic same as httputil.DumpResponse(resp, true) but handles
	// the edge case when we get both error and data on reading resp.Body
	var body, full bytes.Buffer
	if resp.Body != nil {
		_, err = body.ReadFrom(resp.Body)
		// Close the original body on every path, including read errors, so
		// its resources are released; it is replaced below.
		_ = resp.Body.Close()
		if err != nil && body.Len() == 0 {
			return
		}
	}

	// rewind the body to allow full dump
	resp.Body = io.NopCloser(bytes.NewReader(body.Bytes()))
	err = resp.Write(&full)
	fullresp = full.Bytes()

	// rewind once more to allow further reuses
	resp.Body = io.NopCloser(bytes.NewReader(body.Bytes()))
	return
}
69 |
--------------------------------------------------------------------------------
/http/httputil_test.go:
--------------------------------------------------------------------------------
1 | package httputil
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "net/http"
7 | "net/http/httptest"
8 | "strings"
9 | "testing"
10 |
11 | "github.com/stretchr/testify/require"
12 | )
13 |
14 | func TestDumpRequest(t *testing.T) {
15 | req := httptest.NewRequest("GET", "http://example.com/foo", nil)
16 |
17 | reqdump, err := DumpRequest(req)
18 | require.Nil(t, err)
19 | exp := "GET /foo HTTP/1.1\r\nHost: example.com\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n"
20 | require.Equal(t, exp, reqdump)
21 | }
22 |
23 | func TestDumpResponseHeadersAndRaw(t *testing.T) {
24 | expectedResponseBody := "Hello, client"
25 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
26 | w.Header().Del("Date")
27 | fmt.Fprintln(w, expectedResponseBody)
28 | }))
29 | defer ts.Close()
30 |
31 | res, err := http.Get(ts.URL)
32 | require.Nil(t, err)
33 |
34 | headersdumpB, respdumpB, err := DumpResponseHeadersAndRaw(res)
35 | headersdump := string(headersdumpB)
36 | respdump := string(respdumpB)
37 | headersdump = strings.Split(headersdump, "Date")[0]
38 | tokens := strings.Split(respdump, "\r\n")
39 | respdump = ""
40 | for _, token := range tokens {
41 | if !strings.HasPrefix(token, "Date") {
42 | respdump += token + "\r\n"
43 | }
44 | }
45 | require.Nil(t, err)
46 | headers := "HTTP/1.1 200 OK\r\nContent-Length: 14\r\nContent-Type: text/plain; charset=utf-8\r\n"
47 | resp := "HTTP/1.1 200 OK\r\nContent-Length: 14\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nHello, client\n\r\n"
48 | require.Equal(t, headers, headersdump)
49 | require.Equal(t, resp, respdump)
50 |
51 | // ensure that the response body is still readable
52 | respBody, err := io.ReadAll(res.Body)
53 | require.Nil(t, err)
54 | require.Equal(t, expectedResponseBody+"\n", string(respBody))
55 | }
56 |
--------------------------------------------------------------------------------
/http/internal.go:
--------------------------------------------------------------------------------
1 | package httputil
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "io"
7 | "net/http"
8 | "strings"
9 | )
10 |
11 | // implementations copied from stdlib
12 |
13 | // errNoBody is a sentinel error value used by failureToReadBody so we
14 | // can detect that the lack of body was intentional.
15 | var errNoBody = errors.New("sentinel error value")
16 |
17 | // failureToReadBody is an io.ReadCloser that just returns errNoBody on
18 | // Read. It's swapped in when we don't actually want to consume
19 | // the body, but need a non-nil one, and want to distinguish the
20 | // error from reading the dummy body.
21 | type failureToReadBody struct{}
22 |
23 | func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody }
24 | func (failureToReadBody) Close() error { return nil }
25 |
26 | // emptyBody is an instance of empty reader.
27 | var emptyBody = io.NopCloser(strings.NewReader(""))
28 |
// drainBody reads b fully into memory and returns two independent
// ReadClosers that replay the same bytes.
//
// A nil body or http.NoBody is passed through as http.NoBody for both results,
// so the sentinel keeps its meaning. If reading or closing b fails, it returns
// (nil, b, err). It makes no attempt to give the two results identical error
// behavior.
func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
	if b == nil || b == http.NoBody {
		return http.NoBody, http.NoBody, nil
	}
	var buf bytes.Buffer
	_, err = buf.ReadFrom(b)
	if err == nil {
		// only close after a successful read, mirroring the stdlib helper
		err = b.Close()
	}
	if err != nil {
		return nil, b, err
	}
	return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil
}
48 |
--------------------------------------------------------------------------------
/http/response.go:
--------------------------------------------------------------------------------
1 | package httputil
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io"
7 | "net/http"
8 |
9 | "github.com/docker/go-units"
10 | )
11 |
12 | var (
13 | MaxBodyRead, _ = units.FromHumanSize("4mb")
14 | )
15 |
16 | // DumpResponseIntoBuffer dumps a http response without allocating a new buffer
17 | // for the response body.
18 | func DumpResponseIntoBuffer(resp *http.Response, body bool, buff *bytes.Buffer) (err error) {
19 | if resp == nil {
20 | return fmt.Errorf("response is nil")
21 | }
22 | save := resp.Body
23 | savecl := resp.ContentLength
24 |
25 | if !body {
26 | // For content length of zero. Make sure the body is an empty
27 | // reader, instead of returning error through failureToReadBody{}.
28 | if resp.ContentLength == 0 {
29 | resp.Body = emptyBody
30 | } else {
31 | resp.Body = failureToReadBody{}
32 | }
33 | } else if resp.Body == nil {
34 | resp.Body = emptyBody
35 | } else {
36 | save, resp.Body, err = drainBody(resp.Body)
37 | if err != nil {
38 | return err
39 | }
40 | }
41 | err = resp.Write(buff)
42 | if err == errNoBody {
43 | err = nil
44 | }
45 | resp.Body = save
46 | resp.ContentLength = savecl
47 | return
48 | }
49 |
50 | // DrainResponseBody drains the response body and closes it.
51 | func DrainResponseBody(resp *http.Response) {
52 | defer resp.Body.Close()
53 | // don't reuse connection and just close if body length is more than 2 * MaxBodyRead
54 | // to avoid DOS
55 | _, _ = io.CopyN(io.Discard, resp.Body, 2*MaxBodyRead)
56 | }
57 |
--------------------------------------------------------------------------------
/io/io.go:
--------------------------------------------------------------------------------
1 | package ioutil
2 |
3 | import (
4 | "errors"
5 | "io"
6 | "sync"
7 | )
8 |
// SafeWriter is a thread-safe wrapper for io.Writer: every Write is
// serialized through an internal mutex. The zero value is safe to use but has
// no underlying writer, so its Write returns io.ErrClosedPipe. A SafeWriter
// must not be copied after first use.
type SafeWriter struct {
	writer io.Writer  // The underlying writer
	mutex  sync.Mutex // Serializes calls to Write; a value so the zero SafeWriter cannot panic
}

// NewSafeWriter creates and returns a new SafeWriter wrapping writer.
// It returns an error if writer is nil.
func NewSafeWriter(writer io.Writer) (*SafeWriter, error) {
	// Check if the provided writer is nil
	if writer == nil {
		return nil, errors.New("writer is nil")
	}
	return &SafeWriter{writer: writer}, nil
}

// Write implements the io.Writer interface in a thread-safe manner.
// It returns io.ErrClosedPipe when no underlying writer is set.
func (sw *SafeWriter) Write(p []byte) (n int, err error) {
	sw.mutex.Lock()
	defer sw.mutex.Unlock()

	if sw.writer == nil {
		return 0, io.ErrClosedPipe
	}
	return sw.writer.Write(p)
}
39 |
--------------------------------------------------------------------------------
/io/io_test.go:
--------------------------------------------------------------------------------
1 | package ioutil
2 |
3 | import (
4 | "strings"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/require"
8 | )
9 |
10 | func TestSafeWriter(t *testing.T) {
11 | t.Run("success", func(t *testing.T) {
12 | var sb strings.Builder
13 | sw, err := NewSafeWriter(&sb)
14 | require.Nil(t, err)
15 | _, err = sw.Write([]byte("test"))
16 | require.Nil(t, err)
17 | require.Equal(t, "test", sb.String())
18 | })
19 |
20 | t.Run("failure", func(t *testing.T) {
21 | sw, err := NewSafeWriter(nil)
22 | require.NotNil(t, err)
23 | require.Nil(t, sw)
24 | })
25 | }
26 |
--------------------------------------------------------------------------------
/ip/README.md:
--------------------------------------------------------------------------------
1 | # iputil
The package contains various helpers for working with IP addresses and CIDR ranges.
--------------------------------------------------------------------------------
/log/README.md:
--------------------------------------------------------------------------------
1 | # logutil
2 | The package contains helpers to interact with logs
3 |
--------------------------------------------------------------------------------
/log/logutil.go:
--------------------------------------------------------------------------------
1 | package logutil
2 |
3 | import (
4 | "io"
5 | "log"
6 | "os"
7 | )
8 |
// DisableDefaultLogger silences the standard library's default logger by
// sending its output to io.Discard and clearing its flags.
func DisableDefaultLogger() {
	log.SetOutput(io.Discard)
	log.SetFlags(0)
}
14 |
// EnableDefaultLogger restores the default logger's standard configuration:
// output to os.Stderr with log.LstdFlags.
func EnableDefaultLogger() {
	log.SetOutput(os.Stderr)
	log.SetFlags(log.LstdFlags)
}
20 |
--------------------------------------------------------------------------------
/log/logutil_test.go:
--------------------------------------------------------------------------------
1 | package logutil
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "log"
7 | "os"
8 | "testing"
9 |
10 | "github.com/stretchr/testify/require"
11 | )
12 |
13 | func TestDisableDefaultLogger(t *testing.T) {
14 | msg := "sample test"
15 | buf := new(bytes.Buffer)
16 | log.SetOutput(buf)
17 | DisableDefaultLogger()
18 | log.Print(msg)
19 | require.Equal(t, "", buf.String())
20 | }
21 |
22 | func TestEnableDefaultLogger(t *testing.T) {
23 | msg := "sample test"
24 | buf := new(bytes.Buffer)
25 | var stderr = *os.Stderr
26 | r, w, _ := os.Pipe()
27 | os.Stderr = w
28 | exit := make(chan bool)
29 | go func() {
30 | _, _ = io.Copy(buf, r)
31 | exit <- true
32 | }()
33 | EnableDefaultLogger()
34 | log.Print(msg)
35 | w.Close()
36 | <-exit
37 | os.Stderr = &stderr
38 | require.Contains(t, buf.String(), msg)
39 | }
40 |
--------------------------------------------------------------------------------
/maps/README.md:
--------------------------------------------------------------------------------
1 | # mapsutil
2 | The package contains various helpers to interact with maps
--------------------------------------------------------------------------------
/maps/generic_map.go:
--------------------------------------------------------------------------------
1 | package mapsutil
2 |
3 | import "golang.org/x/exp/maps"
4 |
// Map is a generic map type with convenience helpers. Values must be
// comparable so that lookups by value (GetKeyWithValue) are possible.
type Map[K comparable, V comparable] map[K]V
7 |
8 | // Has checks if the current map has the provided key
9 | func (m Map[K, V]) Has(key K) bool {
10 | _, ok := m[key]
11 | return ok
12 | }
13 |
14 | // GetKeys from the map as a slice
15 | func (m Map[K, V]) GetKeys(keys ...K) []V {
16 | values := make([]V, len(keys))
17 | for i, key := range keys {
18 | values[i] = m[key]
19 | }
20 | return values
21 | }
22 |
23 | // GetOrDefault the provided key or default to the provided value
24 | func (m Map[K, V]) GetOrDefault(key K, defaultValue V) V {
25 | if v, ok := m[key]; ok {
26 | return v
27 | }
28 | return defaultValue
29 | }
30 |
31 | // Get returns the value for the provided key
32 | func (m Map[K, V]) Get(key K) (V, bool) {
33 | val, ok := m[key]
34 | return val, ok
35 | }
36 |
37 | // Merge the current map with the provided one
38 | func (m Map[K, V]) Merge(n map[K]V) {
39 | for k, v := range n {
40 | m[k] = v
41 | }
42 | }
43 |
44 | // GetKeyWithValue returns the first key having value
45 | func (m Map[K, V]) GetKeyWithValue(value V) (K, bool) {
46 | var zero K
47 | for k, v := range m {
48 | if v == value {
49 | return k, true
50 | }
51 | }
52 |
53 | return zero, false
54 | }
55 |
56 | // IsEmpty checks if the current map is empty
57 | func (m Map[K, V]) IsEmpty() bool {
58 | return len(m) == 0
59 | }
60 |
61 | // Clone the current map
62 | func (m Map[K, V]) Clone() Map[K, V] {
63 | return maps.Clone(m)
64 | }
65 |
66 | // Set the provided key with the provided value
67 | func (m Map[K, V]) Set(key K, value V) {
68 | m[key] = value
69 | }
70 |
71 | // Clear the map
72 | func (m Map[K, V]) Clear() bool {
73 | maps.Clear(m)
74 | return m.IsEmpty()
75 | }
76 |
--------------------------------------------------------------------------------
/maps/generic_map_test.go:
--------------------------------------------------------------------------------
1 | package mapsutil
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/require"
8 | )
9 |
10 | func TestMapHas(t *testing.T) {
11 | m := Map[string, int]{"foo": 1, "bar": 2}
12 | testCases := []struct {
13 | key string
14 | expected bool
15 | }{
16 | {"foo", true},
17 | {"baz", false},
18 | }
19 | for _, tc := range testCases {
20 | actual := m.Has(tc.key)
21 | if actual != tc.expected {
22 | t.Errorf("Has(%q) = %v, expected %v", tc.key, actual, tc.expected)
23 | }
24 | }
25 | }
26 |
27 | func TestMapGetKeys(t *testing.T) {
28 | m := Map[string, int]{"foo": 1, "bar": 2}
29 | testCases := []struct {
30 | keys []string
31 | expected []int
32 | }{
33 | {[]string{"foo", "bar"}, []int{1, 2}},
34 | {[]string{"baz", "qux"}, []int{0, 0}},
35 | }
36 | for _, tc := range testCases {
37 | actual := m.GetKeys(tc.keys...)
38 | if !reflect.DeepEqual(actual, tc.expected) {
39 | t.Errorf("GetKeys(%v) = %v, expected %v", tc.keys, actual, tc.expected)
40 | }
41 | }
42 | }
43 |
44 | func TestMapGetOrDefault(t *testing.T) {
45 | m := Map[string, int]{"foo": 1, "bar": 2}
46 | testCases := []struct {
47 | key string
48 | defaultV int
49 | expected int
50 | }{
51 | {"foo", 0, 1},
52 | {"baz", 0, 0},
53 | }
54 | for _, tc := range testCases {
55 | actual := m.GetOrDefault(tc.key, tc.defaultV)
56 | if actual != tc.expected {
57 | t.Errorf("GetOrDefault(%q, %d) = %d, expected %d", tc.key, tc.defaultV, actual, tc.expected)
58 | }
59 | }
60 | }
61 |
62 | func TestMapMerge(t *testing.T) {
63 | m := Map[string, int]{"foo": 1, "bar": 2}
64 | n := map[string]int{"baz": 3, "qux": 4}
65 | m.Merge(n)
66 | expected := Map[string, int]{"foo": 1, "bar": 2, "baz": 3, "qux": 4}
67 | if !reflect.DeepEqual(m, expected) {
68 | t.Errorf("Merge(%v) = %v, expected %v", n, m, expected)
69 | }
70 | }
71 |
72 | func TestMap_GetKeyWithValue(t *testing.T) {
73 | type testCase[K, V comparable] struct {
74 | InputMap Map[K, V]
75 | Value V
76 | ExpectedKey K
77 | ExpectedOk bool
78 | }
79 |
80 | genericMap := Map[string, string]{"a": "a", "b": "b", "c": "c"}
81 |
82 | testCases := []testCase[string, string]{
83 | {
84 | InputMap: genericMap,
85 | Value: "b",
86 | ExpectedKey: "b",
87 | ExpectedOk: true,
88 | },
89 | {
90 | InputMap: genericMap,
91 | Value: "d",
92 | ExpectedKey: "",
93 | ExpectedOk: false,
94 | },
95 | {
96 | InputMap: genericMap,
97 | Value: "b",
98 | ExpectedKey: "b",
99 | ExpectedOk: true,
100 | },
101 | {
102 | InputMap: genericMap,
103 | Value: "d",
104 | ExpectedKey: "",
105 | ExpectedOk: false,
106 | },
107 | {
108 | InputMap: genericMap,
109 | Value: "value",
110 | ExpectedKey: "",
111 | ExpectedOk: false,
112 | },
113 | }
114 |
115 | for _, tc := range testCases {
116 | key, ok := tc.InputMap.GetKeyWithValue(tc.Value)
117 | require.Equal(t, tc.ExpectedKey, key)
118 | require.Equal(t, tc.ExpectedOk, ok)
119 | }
120 | }
121 |
--------------------------------------------------------------------------------
/memguardian/README.MD:
--------------------------------------------------------------------------------
1 | ## Mem Guardian Usage Guide
2 |
3 | ### Environment Variables
4 |
5 | - `MEMGUARDIAN`: Enable or disable memguardian. Set to 1 to enable
- `MEMGUARDIAN_MAX_RAM_RATIO`: Maximum RAM usage ratio, as a percentage from 1 to 100
7 | - `MEMGUARDIAN_MAX_RAM`: Maximum amount of RAM (in size units ex: 10gb)
8 | - `MEMGUARDIAN_INTERVAL`: detection interval (with unit ex: 30s)
9 |
10 |
11 |
12 | ## How to Use
13 |
14 | 1. Set the environment variables as per your requirements.
15 |
16 | ```bash
17 | export MEMGUARDIAN=1
18 | export MEMGUARDIAN_MAX_RAM_RATIO=75 # default
19 | export MEMGUARDIAN_MAX_RAM=6Gb # optional
20 | export MEMGUARDIAN_INTERVAL=30s # default
21 | ```
22 |
2. Run your Go application. Memory monitoring starts automatically when `MEMGUARDIAN` is set to 1.
--------------------------------------------------------------------------------
/memguardian/doc.go:
--------------------------------------------------------------------------------
// Package memguardian provides a simple RAM usage control mechanism.
// Once activated, it sets an internal atomic boolean when RAM usage exceeds the
// warning ratio in absolute terms, so callers can check it passively, or it invokes an
// optional callback for reactive backpressure.
5 | package memguardian
6 |
--------------------------------------------------------------------------------
/memguardian/memory.go:
--------------------------------------------------------------------------------
1 | package memguardian
2 |
3 | type SysInfo struct {
4 | Uptime int64
5 | totalRam uint64
6 | freeRam uint64
7 | SharedRam uint64
8 | BufferRam uint64
9 | TotalSwap uint64
10 | FreeSwap uint64
11 | Unit uint64
12 | usedPercent float64
13 | }
14 |
15 | func (si *SysInfo) TotalRam() uint64 {
16 | return uint64(si.totalRam) * uint64(si.Unit)
17 | }
18 |
19 | func (si *SysInfo) FreeRam() uint64 {
20 | return uint64(si.freeRam) * uint64(si.Unit)
21 | }
22 |
23 | func (si *SysInfo) UsedRam() uint64 {
24 | return si.TotalRam() - si.FreeRam()
25 | }
26 |
27 | func (si *SysInfo) UsedPercent() float64 {
28 | if si.usedPercent > 0 {
29 | return si.usedPercent
30 | }
31 |
32 | return 100 * float64((si.TotalRam()-si.FreeRam())*si.Unit) / float64(si.TotalRam())
33 | }
34 |
35 | func GetSysInfo() (*SysInfo, error) {
36 | return getSysInfo()
37 | }
38 |
--------------------------------------------------------------------------------
/memguardian/memory_linux.go:
--------------------------------------------------------------------------------
1 | //go:build linux
2 |
3 | package memguardian
4 |
5 | import "syscall"
6 |
7 | func getSysInfo() (*SysInfo, error) {
8 | var sysInfo syscall.Sysinfo_t
9 | err := syscall.Sysinfo(&sysInfo)
10 | if err != nil {
11 | return nil, err
12 | }
13 |
14 | si := &SysInfo{
15 | Uptime: int64(sysInfo.Uptime),
16 | totalRam: uint64(sysInfo.Totalram),
17 | freeRam: uint64(sysInfo.Freeram),
18 | SharedRam: uint64(sysInfo.Freeram),
19 | BufferRam: uint64(sysInfo.Bufferram),
20 | TotalSwap: uint64(sysInfo.Totalswap),
21 | FreeSwap: uint64(sysInfo.Freeswap),
22 | Unit: uint64(sysInfo.Unit),
23 | }
24 |
25 | return si, nil
26 | }
27 |
--------------------------------------------------------------------------------
/memguardian/memory_others.go:
--------------------------------------------------------------------------------
1 | //go:build !linux
2 |
3 | package memguardian
4 |
5 | import "github.com/shirou/gopsutil/mem"
6 |
7 | // TODO: replace with native syscall
8 | func getSysInfo() (*SysInfo, error) {
9 | vms, err := mem.VirtualMemory()
10 | if err != nil {
11 | return nil, err
12 | }
13 | si := &SysInfo{
14 | totalRam: vms.Total,
15 | freeRam: vms.Free,
16 | SharedRam: vms.Shared,
17 | TotalSwap: vms.SwapTotal,
18 | FreeSwap: vms.SwapFree,
19 | usedPercent: vms.UsedPercent,
20 | }
21 |
22 | return si, nil
23 | }
24 |
--------------------------------------------------------------------------------
/memoize/cmd/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "log"
5 |
6 | "github.com/projectdiscovery/utils/memoize"
7 | )
8 |
9 | func main() {
10 | out, err := memoize.File(memoize.PackageTemplate, "../tests/test.go", "test")
11 | if err != nil {
12 | panic(err)
13 | }
14 | log.Println(string(out))
15 | }
16 |
--------------------------------------------------------------------------------
/memoize/gen/generic/memoize.go:
--------------------------------------------------------------------------------
// This small CLI tool targets functions with arbitrary parameters that return a result-error tuple:
// func(x, y) => result, error
// It generates a memoized version of each annotated file in the same directory, named memo.<original file name>.
// Some parts are specific to nuclei and hardcoded within the template.
5 | package main
6 |
7 | import (
8 | "flag"
9 | "io/fs"
10 | "log"
11 | "os"
12 | "path/filepath"
13 |
14 | "github.com/projectdiscovery/utils/memoize"
15 | stringsutil "github.com/projectdiscovery/utils/strings"
16 | )
17 |
// src is the root directory scanned for Go sources, set with the -src flag.
var src = flag.String("src", "", "go sources")
21 |
22 | func main() {
23 | flag.Parse()
24 |
25 | err := filepath.WalkDir(*src, walkDir)
26 | if err != nil {
27 | log.Fatal(err)
28 | }
29 | }
30 |
31 | func walkDir(path string, d fs.DirEntry, err error) error {
32 | if d.IsDir() {
33 | return nil
34 | }
35 |
36 | if err != nil {
37 | return err
38 | }
39 |
40 | ext := filepath.Ext(path)
41 | base := filepath.Base(path)
42 |
43 | if !stringsutil.EqualFoldAny(ext, ".go") {
44 | return nil
45 | }
46 |
47 | basePath := filepath.Dir(path)
48 | outPath := filepath.Join(basePath, "memo."+base)
49 |
50 | // filename := filepath.Base(path)
51 | data, err := os.ReadFile(path)
52 | if err != nil {
53 | return err
54 | }
55 | if !stringsutil.ContainsAnyI(string(data), "@memo") {
56 | return nil
57 | }
58 | out, err := memoize.Src(memoize.PackageTemplate, path, data, "test")
59 | if err != nil {
60 | return err
61 | }
62 |
63 | if err := os.WriteFile(outPath, out, os.ModePerm); err != nil {
64 | return err
65 | }
66 |
67 | return nil
68 | }
69 |
--------------------------------------------------------------------------------
/memoize/memoize_test.go:
--------------------------------------------------------------------------------
1 | package memoize
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/stretchr/testify/require"
8 | )
9 |
10 | func TestMemo(t *testing.T) {
11 | testingFunc := func() (interface{}, error) {
12 | time.Sleep(10 * time.Second)
13 | return "b", nil
14 | }
15 |
16 | m, err := New(WithMaxSize(5))
17 | require.Nil(t, err)
18 | start := time.Now()
19 | _, _, _ = m.Do("test", testingFunc)
20 | _, _, _ = m.Do("test", testingFunc)
21 | require.True(t, time.Since(start) < time.Duration(15*time.Second))
22 | }
23 |
24 | func TestSrc(t *testing.T) {
25 | out, err := File(PackageTemplate, "tests/test.go", "test")
26 | require.Nil(t, err)
27 | require.True(t, len(out) > 0)
28 | }
29 |
--------------------------------------------------------------------------------
/memoize/package_template.tpl:
--------------------------------------------------------------------------------
{{/* Renders memoized wrappers for each entry of .Functions. Data used: .PackageName, .Imports (.Name, .Path) and, per function, .Name, .Signature, .SourcePackage, .ParamsNames, .WantReturn, .WantSyncOnce, .SyncOnceVarName, .ResultStructType, .ResultStructVarName, .ResultStructFields, .Results (.ResultName, .Type). */}}
package {{.PackageName}}

import (
"github.com/projectdiscovery/utils/memoize"

{{/* NOTE(review): the rendered code also uses bytes, fmt, crypto/sha256, encoding/hex (hash) and sync (sync.Once); presumably the generator supplies them via .Imports — confirm. */}}
{{range .Imports}}
{{.Name}} {{.Path}}
{{end}}
)

{{range .Functions}}
{{/* Functions with results get a struct type that holds all of them. */}}
{{ if .WantReturn }}
type {{ .ResultStructType }} struct {
{{ range .Results }}
{{ .ResultName }} {{ .Type }}
{{ end }}
}
{{ end }}
var (
{{ if .WantSyncOnce }}
{{ .SyncOnceVarName }} sync.Once

{{ if .WantReturn }}
{{ .ResultStructVarName }} {{ .ResultStructType }}
{{ end }}

{{ end }}
)

{{ .Signature }} {
{{/* sync.Once path: the source function is called once with no arguments and its results are stored in the package-level vars above. */}}
{{ if .WantSyncOnce }}

{{ .SyncOnceVarName }}.Do(func() {
{{ if .WantReturn }}
{{ .ResultStructFields }} = {{.SourcePackage}}.{{.Name}}()
{{ else }}
{{.SourcePackage}}.{{.Name}}()
{{ end }}
})

{{ if .WantReturn }}
return {{ .ResultStructFields }}
{{ end }}

{{/* Cache path: results are keyed by hash(function name, arguments) in the shared memoize cache. */}}
{{ else }}

h := hash("{{.Name}}", {{.ParamsNames}})
v, _, _ := cache.Do(h, func() (interface{}, error) {
{{ if .WantReturn }}
{{.ResultStructVarName}} := &{{.ResultStructType}}{}
{{ .ResultStructFields }} = {{.SourcePackage}}.{{.Name}}({{.ParamsNames}})
return {{.ResultStructVarName}}, nil
{{else}}
{{.SourcePackage}}.{{.Name}}({{.ParamsNames}})
return nil, nil
{{end}}
})
{{ if .WantReturn }}
{{.ResultStructVarName}} := v.(*{{.ResultStructType}})
{{else}}
_ = v
{{end}}

{{ if .WantReturn }}
return {{ .ResultStructFields }}
{{ end }}

{{ end }}
}
{{end}}
71 |
// hash derives a cache key from a function name and its arguments, returned as
// the hex-encoded SHA-256 digest of the encoded input.
//
// Each argument is rendered with fmt.Sprint and length-prefixed, so different
// argument lists cannot collapse into the same byte sequence, e.g. ("ab", "c")
// versus ("a", "bc"), or ("", "") versus (""). Arguments whose fmt.Sprint
// output is identical (e.g. distinct values with the same printed form) still
// share a key.
func hash(functionName string, args ...any) string {
	var b bytes.Buffer
	b.WriteString(functionName + ":")
	for _, arg := range args {
		s := fmt.Sprint(arg)
		fmt.Fprintf(&b, "%d:%s;", len(s), s)
	}
	h := sha256.Sum256(b.Bytes())
	return hex.EncodeToString(h[:])
}
81 |
82 | var cache *memoize.Memoizer
83 |
84 | func init() {
85 | cache, _ = memoize.New(memoize.WithMaxSize(1000))
86 | }
87 |
88 |
--------------------------------------------------------------------------------
/memoize/templates.go:
--------------------------------------------------------------------------------
1 | package memoize
2 |
3 | import _ "embed"
4 |
// PackageTemplate holds the contents of package_template.tpl, embedded at
// build time. The memoize generators pass it to memoize.File and memoize.Src.
//
//go:embed package_template.tpl
var PackageTemplate string
7 |
--------------------------------------------------------------------------------
/memoize/tests/test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "errors"
5 | "time"
6 | )
7 |
// @memo
func Test(a string, b string) string {
	// Generator fixture: two parameters and a single result.
	return "something"
}
12 |
// @memo
func TestWithArgs(a string, b string) {
	// Generator fixture: parameters but no results.
}
17 |
// @memo
func TestNothing() {
	// Generator fixture: no parameters and no results; the sleep presumably
	// stands in for slow work worth memoizing.
	time.Sleep(time.Second)
}
22 |
// @memo
func TestWithOneReturn() string {
	// Generator fixture: no parameters and a single result.
	return "a"
}
27 |
// @memo
func TestWithMultipleReturnValues() (string, int, error) {
	// Generator fixture: no parameters and several results, including an error.
	return "a", 2, errors.New("test")
}
32 |
--------------------------------------------------------------------------------
/ml/metrics/classification_report.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
3 | import (
4 | "fmt"
5 | "strings"
6 | )
7 |
8 | func (cm *ConfusionMatrix) PrintClassificationReport() string {
9 | var s strings.Builder
10 | s.WriteString(fmt.Sprintf("%30s\n", "Classification Report"))
11 | s.WriteString(fmt.Sprintln())
12 |
13 | s.WriteString(fmt.Sprintf("\n%-15s %-10s %-10s %-10s %-10s\n", "", "precision", "recall", "f1-score", "support"))
14 |
15 | totals := map[string]float64{"true": 0, "predicted": 0, "correct": 0}
16 | macroAvg := map[string]float64{"precision": 0, "recall": 0, "f1-score": 0}
17 |
18 | for i, label := range cm.labels {
19 | truePos := cm.matrix[i][i]
20 | falsePos, falseNeg := 0, 0
21 | for j := 0; j < len(cm.labels); j++ {
22 | if i != j {
23 | falsePos += cm.matrix[j][i]
24 | falseNeg += cm.matrix[i][j]
25 | }
26 | }
27 |
28 | precision := float64(truePos) / float64(truePos+falsePos)
29 | recall := float64(truePos) / float64(truePos+falseNeg)
30 | f1Score := 2 * precision * recall / (precision + recall)
31 | support := truePos + falseNeg
32 |
33 | fmt.Printf("%-15s %-10.2f %-10.2f %-10.2f %-10d\n", label, precision, recall, f1Score, support)
34 |
35 | totals["true"] += float64(support)
36 | totals["predicted"] += float64(truePos + falsePos)
37 | totals["correct"] += float64(truePos)
38 |
39 | macroAvg["precision"] += precision
40 | macroAvg["recall"] += recall
41 | macroAvg["f1-score"] += f1Score
42 | }
43 |
44 | accuracy := totals["correct"] / totals["true"]
45 | s.WriteString(fmt.Sprintf("\n%-26s %-10s %-10.2f %-10d", "accuracy", "", accuracy, int(totals["true"])))
46 |
47 | s.WriteString(fmt.Sprintf("\n%-15s %-10.2f %-10.2f %-10.2f %-10d\n", "macro avg",
48 | macroAvg["precision"]/float64(len(cm.labels)),
49 | macroAvg["recall"]/float64(len(cm.labels)),
50 | macroAvg["f1-score"]/float64(len(cm.labels)),
51 | int(totals["true"])))
52 |
53 | precisionWeightedAvg := totals["correct"] / totals["predicted"]
54 | recallWeightedAvg := totals["correct"] / totals["true"]
55 | f1ScoreWeightedAvg := 2 * precisionWeightedAvg * recallWeightedAvg / (precisionWeightedAvg + recallWeightedAvg)
56 |
57 | s.WriteString(fmt.Sprintf("%-15s %-10.2f %-10.2f %-10.2f %-10d\n", "weighted avg",
58 | precisionWeightedAvg, recallWeightedAvg, f1ScoreWeightedAvg, int(totals["true"])))
59 |
60 | s.WriteString(fmt.Sprintln())
61 |
62 | return s.String()
63 | }
64 |
--------------------------------------------------------------------------------
/ml/metrics/confusion_matrix.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
3 | import (
4 | "fmt"
5 | "strings"
6 | )
7 |
// ConfusionMatrix counts classification outcomes. matrix[i][j] is the number
// of samples whose actual label is labels[i] and whose predicted label is
// labels[j].
type ConfusionMatrix struct {
	matrix [][]int
	labels []string
}

// NewConfusionMatrix builds a confusion matrix over labels from paired
// actual/predicted slices:
//   - Pairs are taken position by position, up to the shorter of the two
//     slices.
//   - A pair is skipped when either value is not one of labels; it is not
//     counted against labels[0].
//   - If a label appears more than once in labels, its last index is used.
func NewConfusionMatrix(actual, predicted []string, labels []string) *ConfusionMatrix {
	n := len(labels)
	matrix := make([][]int, n)
	for i := range matrix {
		matrix[i] = make([]int, n)
	}

	labelIndices := make(map[string]int, n)
	for i, label := range labels {
		labelIndices[label] = i
	}

	pairs := len(actual)
	if len(predicted) < pairs {
		pairs = len(predicted)
	}
	for i := 0; i < pairs; i++ {
		row, okActual := labelIndices[actual[i]]
		col, okPredicted := labelIndices[predicted[i]]
		if !okActual || !okPredicted {
			continue
		}
		matrix[row][col]++
	}

	return &ConfusionMatrix{
		matrix: matrix,
		labels: labels,
	}
}

// PrintConfusionMatrix renders the matrix as a fixed-width text table and
// returns it. The table has a header row of predicted labels, then one row per
// actual label.
func (cm *ConfusionMatrix) PrintConfusionMatrix() string {
	var s strings.Builder

	fmt.Fprintf(&s, "%30s\n", "Confusion Matrix")
	s.WriteString("\n")
	// header row: empty corner cell, then one column per label
	fmt.Fprintf(&s, "%-15s", "")
	for _, label := range cm.labels {
		fmt.Fprintf(&s, "%-15s", label)
	}
	s.WriteString("\n")

	// one row per actual label
	for i, row := range cm.matrix {
		fmt.Fprintf(&s, "%-15s", cm.labels[i])
		for _, value := range row {
			fmt.Fprintf(&s, "%-15d", value)
		}
		s.WriteString("\n")
	}
	s.WriteString("\n")

	return s.String()
}
59 |
--------------------------------------------------------------------------------
/ml/model_selection/model_selection.go:
--------------------------------------------------------------------------------
1 | package modelselection
2 |
3 | import (
4 | "math/rand"
5 | )
6 |
// TrainTestSplit randomly partitions dataset. Each element independently goes
// to train when a math/rand Float64 draw exceeds testSize, and to test
// otherwise, so testSize is the expected test fraction. Relative order is kept
// within each part, and a part with no elements is nil.
func TrainTestSplit(dataset []interface{}, testSize float64) (train, test []interface{}) {
	for i := range dataset {
		sample := dataset[i]
		if rand.Float64() > testSize {
			train = append(train, sample)
			continue
		}
		test = append(test, sample)
	}
	return train, test
}
17 |
--------------------------------------------------------------------------------
/ml/naive_bayes/naive_bayes_classifier_test.go:
--------------------------------------------------------------------------------
1 | package naive_bayes
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestNaiveBayesClassifier(t *testing.T) {
10 | // Create a new Naive Bayes Classifier
11 | threshold := 1.1
12 | nb := New(threshold)
13 |
14 | // Create a new training set
15 | trainingSet := map[string][]string{
16 | "Baseball": {
17 | "Pitcher",
18 | "Shortstop",
19 | "Outfield",
20 | },
21 | "Basketball": {
22 | "Point Guard",
23 | "Shooting Guard",
24 | "Small Forward",
25 | "Power Forward",
26 | "Center",
27 | },
28 | "Soccer": {
29 | "Goalkeeper",
30 | "Defender",
31 | "Midfielder",
32 | "Forward",
33 | },
34 | }
35 |
36 | // Train the classifier
37 | nb.Fit(trainingSet)
38 |
39 | //then
40 | assert.Equal(t, nb.Classify("Point guard"), "Basketball")
41 | }
42 |
--------------------------------------------------------------------------------
/ml/types.go:
--------------------------------------------------------------------------------
1 | package mlutils
2 |
// LabeledDocument pairs a document's text with its class label.
type LabeledDocument struct {
	Label, Document string
}
7 |
--------------------------------------------------------------------------------
/net/net.go:
--------------------------------------------------------------------------------
1 | package netutil
2 |
3 | import (
4 | "errors"
5 | "net"
6 | )
7 |
// ErrMissingPort is returned by TryJoinHostPort when the port argument is empty.
var ErrMissingPort = errors.New("missing port")

// TryJoinHostPort combines host and port into an address using net.JoinHostPort.
//
// Outcomes:
//   - empty host: returns "" and a *net.AddrError ("missing host")
//   - empty port: returns host unchanged together with ErrMissingPort, so the
//     caller can still fall back to the bare host
//   - otherwise: returns net.JoinHostPort(host, port) and a nil error
func TryJoinHostPort(host, port string) (string, error) {
	switch {
	case host == "":
		return "", &net.AddrError{Err: "missing host", Addr: host}
	case port == "":
		return host, ErrMissingPort
	default:
		return net.JoinHostPort(host, port), nil
	}
}
22 |
--------------------------------------------------------------------------------
/net/net_test.go:
--------------------------------------------------------------------------------
1 | package netutil
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestTryJoinHostPort(t *testing.T) {
8 | tests := []struct {
9 | name string
10 | host string
11 | port string
12 | want string
13 | wantErr bool
14 | }{
15 | {
16 | name: "both host and port provided",
17 | host: "localhost",
18 | port: "8080",
19 | want: "localhost:8080",
20 | wantErr: false,
21 | },
22 | {
23 | name: "empty host",
24 | host: "",
25 | port: "8080",
26 | want: "",
27 | wantErr: true,
28 | },
29 | {
30 | name: "empty port",
31 | host: "localhost",
32 | port: "",
33 | want: "localhost",
34 | wantErr: true,
35 | },
36 | {
37 | name: "both host and port empty",
38 | host: "",
39 | port: "",
40 | want: "",
41 | wantErr: true,
42 | },
43 | }
44 |
45 | for _, tt := range tests {
46 | t.Run(tt.name, func(t *testing.T) {
47 | got, err := TryJoinHostPort(tt.host, tt.port)
48 | if (err != nil) != tt.wantErr {
49 | t.Errorf("TryJoinHostPort() error = %v, wantErr %v", err, tt.wantErr)
50 | return
51 | }
52 | if got != tt.want {
53 | t.Errorf("TryJoinHostPort() = %v, want %v", got, tt.want)
54 | }
55 | })
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/os/arch.go:
--------------------------------------------------------------------------------
1 | package osutils
2 |
3 | import "runtime"
4 |
// ArchType identifies a CPU architecture. Each named constant corresponds to
// the runtime.GOARCH value of the same name (I386 is GOARCH "386").
type ArchType uint8

const (
	I386 ArchType = iota
	Amd64
	Amd64p32
	Arm
	Armbe
	Arm64
	Arm64be
	Loong64
	Mips
	Mipsle
	Mips64
	Mips64le
	Mips64p32
	Mips64p32le
	Ppc
	Ppc64
	Ppc64le
	Riscv
	Riscv64
	S390
	S390x
	Sparc
	Sparc64
	Wasm
	// UknownArch is used when runtime.GOARCH matches none of the constants above.
	UknownArch

	// UnknownArch is a correctly spelled alias of UknownArch.
	UnknownArch = UknownArch
)

// Arch is the architecture the running binary was built for, resolved once at
// package initialization.
var Arch ArchType

func init() {
	Arch = archFromGOARCH(runtime.GOARCH)
}

// archFromGOARCH maps a runtime.GOARCH value to its ArchType constant. Every
// declared architecture is recognized; anything else yields UknownArch.
func archFromGOARCH(goarch string) ArchType {
	switch goarch {
	case "386":
		return I386
	case "amd64":
		return Amd64
	case "amd64p32":
		return Amd64p32
	case "arm":
		return Arm
	case "armbe":
		return Armbe
	case "arm64":
		return Arm64
	case "arm64be":
		return Arm64be
	case "loong64":
		return Loong64
	case "mips":
		return Mips
	case "mipsle":
		return Mipsle
	case "mips64":
		return Mips64
	case "mips64le":
		return Mips64le
	case "mips64p32":
		return Mips64p32
	case "mips64p32le":
		return Mips64p32le
	case "ppc":
		return Ppc
	case "ppc64":
		return Ppc64
	case "ppc64le":
		return Ppc64le
	case "riscv":
		return Riscv
	case "riscv64":
		return Riscv64
	case "s390":
		return S390
	case "s390x":
		return S390x
	case "sparc":
		return Sparc
	case "sparc64":
		return Sparc64
	case "wasm":
		return Wasm
	default:
		return UknownArch
	}
}

// Is386 reports whether the binary targets 32-bit x86 (GOARCH "386").
func Is386() bool {
	return runtime.GOARCH == "386"
}

// IsAmd64 reports whether the binary targets GOARCH "amd64".
func IsAmd64() bool {
	return runtime.GOARCH == "amd64"
}

// IsARM reports whether the binary targets 32-bit ARM (GOARCH "arm").
func IsARM() bool {
	return runtime.GOARCH == "arm"
}

// IsARM64 reports whether the binary targets GOARCH "arm64".
func IsARM64() bool {
	return runtime.GOARCH == "arm64"
}

// IsWasm reports whether the binary targets WebAssembly (GOARCH "wasm").
func IsWasm() bool {
	return runtime.GOARCH == "wasm"
}
73 |
--------------------------------------------------------------------------------
/os/os.go:
--------------------------------------------------------------------------------
1 | package osutils
2 |
3 | import "runtime"
4 |
// OsType identifies an operating system. Each named constant corresponds to
// one runtime.GOOS value (Darwin is GOOS "darwin", IOS is "ios", ...).
type OsType uint8

const (
	Darwin OsType = iota
	// Windows corresponds to GOOS "windows".
	Windows
	Linux
	Android
	IOS
	FreeBSD
	OpenBSD
	JS
	Solaris
	// UnknownOS is used when runtime.GOOS matches none of the constants above.
	UnknownOS

	// windows is kept as an unexported alias of Windows so existing in-package
	// references keep compiling.
	windows = Windows
)

// OS is the operating system the running binary was built for, resolved once
// at package initialization.
var OS OsType

func init() {
	OS = osFromGOOS(runtime.GOOS)
}

// osFromGOOS maps a runtime.GOOS value to its OsType constant, or UnknownOS.
func osFromGOOS(goos string) OsType {
	switch goos {
	case "darwin":
		return Darwin
	case "linux":
		return Linux
	case "windows":
		return Windows
	case "android":
		return Android
	case "ios":
		return IOS
	case "js":
		return JS
	case "freebsd":
		return FreeBSD
	case "openbsd":
		return OpenBSD
	case "solaris":
		return Solaris
	default:
		return UnknownOS
	}
}

// IsOSX reports whether the binary targets macOS (GOOS "darwin").
func IsOSX() bool {
	return runtime.GOOS == "darwin"
}

// IsLinux reports whether the binary targets GOOS "linux".
func IsLinux() bool {
	return runtime.GOOS == "linux"
}

// IsWindows reports whether the binary targets GOOS "windows".
func IsWindows() bool {
	return runtime.GOOS == "windows"
}

// IsAndroid reports whether the binary targets GOOS "android".
func IsAndroid() bool {
	return runtime.GOOS == "android"
}

// IsIOS reports whether the binary targets GOOS "ios".
func IsIOS() bool {
	return runtime.GOOS == "ios"
}

// IsFreeBSD reports whether the binary targets GOOS "freebsd".
func IsFreeBSD() bool {
	return runtime.GOOS == "freebsd"
}

// IsOpenBSD reports whether the binary targets GOOS "openbsd".
func IsOpenBSD() bool {
	return runtime.GOOS == "openbsd"
}

// IsJS reports whether the binary targets GOOS "js".
func IsJS() bool {
	return runtime.GOOS == "js"
}

// IsSolaris reports whether the binary targets GOOS "solaris".
func IsSolaris() bool {
	return runtime.GOOS == "solaris"
}
82 |
--------------------------------------------------------------------------------
/patterns/doc.go:
--------------------------------------------------------------------------------
1 | // Package patterns contains various common patterns.
2 | // Some regexps were extended from https://github.com/asaskevich/govalidator
3 | package patterns
4 |
--------------------------------------------------------------------------------
/permission/README.md:
--------------------------------------------------------------------------------
1 | # permissionutil
2 | The package contains helpers for file permission modes and for inspecting process privileges (root/administrator, CAP_NET_RAW).
3 |
--------------------------------------------------------------------------------
/permission/error.go:
--------------------------------------------------------------------------------
1 | package permissionutil
2 |
3 | import "errors"
4 |
// ErrNotImplemented is returned by privilege checks that are not available on
// the current platform.
var ErrNotImplemented error = errors.New("not implemented")
6 |
--------------------------------------------------------------------------------
/permission/permission.go:
--------------------------------------------------------------------------------
1 | package permissionutil
2 |
3 | var (
4 | IsRoot bool
5 | HasCapNetRaw bool
6 | )
7 |
8 | func init() {
9 | IsRoot, _ = checkCurrentUserRoot()
10 | HasCapNetRaw, _ = checkCurrentUserCapNetRaw()
11 | }
12 |
--------------------------------------------------------------------------------
/permission/permission_file.go:
--------------------------------------------------------------------------------
1 | package permissionutil
2 |
3 | import "os"
4 |
5 | // Set permissions on an existing file with file.Chmod(os.FileMode(perm)) or UpdateFilePerm.
6 | // Example: file.Chmod(os.FileMode(AllReadWriteExecute))
7 | // Permissions passed to os.OpenFile (or os.Mkdir) when creating a file are filtered by the
8 | // process umask, so the resulting mode may be more restrictive than the one requested.
9 | // https://stackoverflow.com/questions/66097279/why-will-os-openfile-not-create-a-777-file
10 |
const (
	// Permission bits for a single class (read/write/execute) and the shift that
	// places them in the user (owner), group or other position of a Unix mode.
	os_read        = 0o4
	os_write       = 0o2
	os_ex          = 0o1
	os_user_shift  = 6
	os_group_shift = 3
	os_other_shift = 0

	// User (owner) permissions.
	UserRead             = 0o400
	UserWrite            = 0o200
	UserExecute          = 0o100
	UserReadWrite        = 0o600
	UserReadWriteExecute = 0o700

	// Group permissions.
	GroupRead             = 0o040
	GroupWrite            = 0o020
	GroupExecute          = 0o010
	GroupReadWrite        = 0o060
	GroupReadWriteExecute = 0o070

	// Other (everyone else) permissions.
	OtherRead             = 0o004
	OtherWrite            = 0o002
	OtherExecute          = 0o001
	OtherReadWrite        = 0o006
	OtherReadWriteExecute = 0o007

	// Permissions for user, group and other combined.
	AllRead             = 0o444
	AllWrite            = 0o222
	AllExecute          = 0o111
	AllReadWrite        = 0o666
	AllReadWriteExecute = 0o777

	// Default modes for config folders/files, binaries and temporary files.
	ConfigFolderPermission = UserReadWriteExecute     // 0o700
	ConfigFilePermission   = UserReadWrite            // 0o600
	BinaryPermission       = UserRead | UserExecute   // 0o500
	TempFilePermission     = UserReadWrite            // 0o600
)
53 |
// UpdateFilePerm sets the permission bits of filename to perm (for example
// ConfigFilePermission) using os.Chmod. Unlike file creation, chmod is not
// filtered by the umask, so the mode is applied as given.
// The error from os.Chmod, if any, is returned unchanged.
func UpdateFilePerm(filename string, perm int) error {
	return os.Chmod(filename, os.FileMode(perm))
}
60 |
--------------------------------------------------------------------------------
/permission/permission_linux.go:
--------------------------------------------------------------------------------
1 | //go:build linux && !(armv7l || armv8l) && !android
2 |
3 | package permissionutil
4 |
5 | import (
6 | "errors"
7 | "os"
8 | "runtime"
9 |
10 | raceutil "github.com/projectdiscovery/utils/race"
11 |
12 | "golang.org/x/sys/unix"
13 | )
14 |
// checkCurrentUserRoot reports whether the process runs with effective UID 0.
// It never returns an error on this platform.
func checkCurrentUserRoot() (bool, error) {
	const rootUID = 0
	isRoot := os.Geteuid() == rootUID
	return isRoot, nil
}
19 |
20 | // checkCurrentUserCapNetRaw checks if the current user has the CAP_NET_RAW capability
21 | func checkCurrentUserCapNetRaw() (bool, error) {
22 | if raceutil.Enabled {
23 | return false, errors.New("race detector enabled")
24 | }
25 | // runtime.LockOSThread interferes with race detection
26 | header := unix.CapUserHeader{
27 | Version: unix.LINUX_CAPABILITY_VERSION_3,
28 | Pid: int32(os.Getpid()),
29 | }
30 | data := [2]unix.CapUserData{}
31 | runtime.LockOSThread()
32 | defer runtime.UnlockOSThread()
33 |
34 | err := unix.Capget(&header, &data[0])
35 | if err != nil {
36 | return false, err
37 | }
38 | data[0].Inheritable = (1 << unix.CAP_NET_RAW)
39 | if err = unix.Capset(&header, &data[0]); err != nil {
40 | return false, err
41 | }
42 | return true, nil
43 | }
44 |
--------------------------------------------------------------------------------
/permission/permission_other.go:
--------------------------------------------------------------------------------
1 | //go:build darwin || freebsd || netbsd || dragonfly || openbsd || solaris || android || ios || (linux && armv7l) || (linux && armv8l)
2 |
3 | package permissionutil
4 |
5 | import (
6 | "os"
7 | )
8 |
// checkCurrentUserRoot reports whether the effective user ID is 0 (root).
// The error result is always nil on these platforms.
func checkCurrentUserRoot() (bool, error) {
	if euid := os.Geteuid(); euid == 0 {
		return true, nil
	}
	return false, nil
}
13 |
14 | // checkCurrentUserCapNetRaw checks if the current user has the CAP_NET_RAW capability
15 | func checkCurrentUserCapNetRaw() (bool, error) {
16 | return false, ErrNotImplemented
17 | }
18 |
--------------------------------------------------------------------------------
/permission/permission_test.go:
--------------------------------------------------------------------------------
1 | //go:build windows || linux
2 |
3 | package permissionutil
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/stretchr/testify/require"
9 | )
10 |
11 | func TestIsRoot(t *testing.T) {
12 | isRoot, err := checkCurrentUserRoot()
13 | require.Nil(t, err)
14 | require.NotNil(t, isRoot)
15 | }
16 |
--------------------------------------------------------------------------------
/permission/permission_win.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package permissionutil
4 |
5 | import (
6 | "golang.org/x/sys/windows"
7 | )
8 |
9 | // checkCurrentUserRoot on Windows
10 | // from https://github.com/golang/go/issues/28804#issuecomment-505326268
11 | func checkCurrentUserRoot() (bool, error) {
12 | var sid *windows.SID
13 |
14 | // Although this looks scary, it is directly copied from the
15 | // official windows documentation. The Go API for this is a
16 | // direct wrap around the official C++ API.
17 | // See https://docs.microsoft.com/en-us/windows/desktop/api/securitybaseapi/nf-securitybaseapi-checktokenmembership
18 | err := windows.AllocateAndInitializeSid(
19 | &windows.SECURITY_NT_AUTHORITY,
20 | 2,
21 | windows.SECURITY_BUILTIN_DOMAIN_RID,
22 | windows.DOMAIN_ALIAS_RID_ADMINS,
23 | 0, 0, 0, 0, 0, 0,
24 | &sid)
25 | if err != nil {
26 | return false, err
27 | }
28 |
29 | defer func() { _ = windows.FreeSid(sid) }()
30 |
31 | // This appears to cast a null pointer so I'm not sure why this
32 | // works, but this guy says it does and it Works for Me™:
33 | // https://github.com/golang/go/issues/28804#issuecomment-438838144
34 | token := windows.Token(0)
35 |
36 | member, err := token.IsMember(sid)
37 | if err != nil {
38 | return false, err
39 | }
40 |
41 | // Also note that an admin is _not_ necessarily considered
42 | // elevated.
43 | // For elevation see https://github.com/mozey/run-as-admin
44 | return token.IsElevated() || member, nil
45 | }
46 |
47 | // checkCurrentUserCapNetRaw on windows is not implemented
48 | func checkCurrentUserCapNetRaw() (bool, error) {
49 | return false, ErrNotImplemented
50 | }
51 |
--------------------------------------------------------------------------------
/ports/ports.go:
--------------------------------------------------------------------------------
1 | package ports
2 |
3 | import (
4 | "strconv"
5 | )
6 |
// IsValid reports whether v holds a valid port number (1-65535).
//
// v may be a string (see IsValidWithString) or any built-in signed or unsigned
// integer type. Every other type, including floating-point values, is reported
// as invalid.
func IsValid(v interface{}) bool {
	switch p := v.(type) {
	case string:
		return IsValidWithString(p)
	case int:
		return IsValidWithInt(p)
	case int8:
		return isValidSigned(int64(p))
	case int16:
		return isValidSigned(int64(p))
	case int32:
		return isValidSigned(int64(p))
	case int64:
		return isValidSigned(p)
	case uint:
		return isValidUnsigned(uint64(p))
	case uint8:
		return isValidUnsigned(uint64(p))
	case uint16:
		return isValidUnsigned(uint64(p))
	case uint32:
		return isValidUnsigned(uint64(p))
	case uint64:
		return isValidUnsigned(p)
	}
	return false
}

// IsValidWithString reports whether p is the base-10 text of a valid port.
// Parsing uses strconv.Atoi: surrounding whitespace is rejected, while an
// optional leading sign (e.g. "+80") is accepted.
func IsValidWithString(p string) bool {
	port, err := strconv.Atoi(p)
	return err == nil && IsValidWithInt(port)
}

// IsValidWithInt reports whether port lies in the valid range 1-65535.
func IsValidWithInt(port int) bool {
	return port >= 1 && port <= 65535
}

// isValidSigned is the int64 form of IsValidWithInt. Using int64 avoids
// truncating wide values on 32-bit platforms.
func isValidSigned(port int64) bool {
	return port >= 1 && port <= 65535
}

// isValidUnsigned is the uint64 counterpart of isValidSigned.
func isValidUnsigned(port uint64) bool {
	return port >= 1 && port <= 65535
}
28 |
--------------------------------------------------------------------------------
/ports/ports_test.go:
--------------------------------------------------------------------------------
1 | package ports
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/require"
7 | )
8 |
9 | func TestIsValid(t *testing.T) {
10 | t.Run("valid-ports-strings", func(t *testing.T) {
11 | ports := []interface{}{"1", "10000", "65535", 1, 10000, 65535}
12 | for _, port := range ports {
13 | require.True(t, IsValid(port))
14 | }
15 | })
16 | t.Run("invalid-ports", func(t *testing.T) {
17 | ports := []interface{}{"", "-1", "0", "65536", 0, -1, 65536, 2.1, "a"}
18 | for _, port := range ports {
19 | require.False(t, IsValid(port))
20 | }
21 | })
22 | }
23 |
--------------------------------------------------------------------------------
/pprof/README.md:
--------------------------------------------------------------------------------
1 | ## PProfiling Usage Guide
2 |
3 | Two types of profiling are supported:
4 |
5 | 1. **pprof**: Standard go profiling writing to files in a directory.
6 | 2. **server**: Profiling server listening on a port with pprof and fgprof endpoints.
7 |
8 | ### pprof
9 |
10 | #### Environment Variables
11 |
12 | - `PPROF`: Enable or disable profiling. Set to 1 to enable.
13 | - `MEM_PROFILE_DIR`: Directory to write memory profiles to (default `memdump`).
14 | - `CPU_PROFILE_DIR`: Directory to write CPU profiles to (default `cpuprofile`).
15 | - `PPROF_TIME`: Polling time for CPU and memory profiles, with unit (e.g. 10s; default 3s).
16 | - `MEM_PROFILE_RATE`: Memory profiling rate (default 4096).
17 |
18 |
19 | ### How to Use
20 |
21 | 1. Set the environment variables as per your requirements.
22 |
23 | ```bash
24 | export PPROF=1
25 | export MEM_PROFILE_DIR=/path/to/memprofile
26 | export CPU_PROFILE_DIR=/path/to/cpuprofile
27 | export PPROF_TIME=10s
28 | export MEM_PROFILE_RATE=4096
29 | ```
30 |
31 | 2. Run your Go application. The profiler will start automatically if PPROF is set to 1.
32 |
33 | **Output**
34 |
35 | - Memory profiles will be written to the directory specified by MEM_PROFILE_DIR.
36 | - CPU profiles will be written to the directory specified by CPU_PROFILE_DIR.
37 | - Profiles will be written at intervals specified by PPROF_TIME.
38 | - Memory profiling rate is controlled by MEM_PROFILE_RATE.
39 |
40 | #### Example
41 |
42 | ```bash
43 | [+] GOOS: linux
44 | [+] GOARCH: amd64
45 | [+] Command: /path/to/your/app
46 | Available PPROF Config Options:
47 | MEM_PROFILE_DIR - directory to write memory profiles to
48 | CPU_PROFILE_DIR - directory to write cpu profiles to
49 | PPROF_TIME - polling time for cpu and memory profiles (with unit ex: 10s)
50 | MEM_PROFILE_RATE - memory profiling rate (default 4096)
51 | profile: memory profiling enabled (rate 4096), /path/to/memprofile
52 | profile: ticker enabled (rate 10s)
53 | profile: cpu profiling enabled (ticker 10s)
54 | ```
55 |
56 | #### Note
57 |
58 | - The polling time (PPROF_TIME) should be set according to your application's performance and profiling needs.
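59 | - The memory profiling rate (MEM_PROFILE_RATE) controls the granularity of memory profiling: the runtime samples about one allocation per MEM_PROFILE_RATE bytes allocated, so lower values provide more detail but consume more resources (1 records every allocation).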
60 |
61 | ### server
62 |
63 | Server is a simple HTTP server listening on a port with pprof and fgprof endpoints.
64 |
65 | #### Environment Variables
66 |
67 | - `PPROF_SERVER_ADDRESS`: Address to listen on for pprof and fgprof server (default 127.0.0.1:8086).
68 |
69 | #### Endpoints
70 |
71 | - /debug/pprof/
72 | - /debug/pprof/cmdline
73 | - /debug/pprof/profile
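74 | - /debug/pprof/&lt;name&gt; (named profiles such as heap, goroutine, allocs, block and mutex, served by the index handler)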
75 | - /debug/pprof/symbol
76 | - /debug/pprof/trace
77 | - /debug/fgprof
78 |
79 | #### Example
80 |
81 | ```console
82 | go tool pprof http://127.0.0.1:8086/debug/fgprof
83 | ```
--------------------------------------------------------------------------------
/pprof/pprof.go:
--------------------------------------------------------------------------------
1 | package pprof
2 |
3 | import (
4 | "bytes"
5 | "log"
6 | "os"
7 | "path/filepath"
8 | "runtime"
9 | "runtime/pprof"
10 | "strconv"
11 | "strings"
12 | "time"
13 |
14 | "github.com/projectdiscovery/utils/env"
15 | )
16 |
// Environment variables read by the file-based default profiler.
const (
	// PPROFSwitchENV starts the profiler at package init when set to 1.
	PPROFSwitchENV = "PPROF"
	// MemProfileENV names the directory heap profiles are written to.
	MemProfileENV = "MEM_PROFILE_DIR"
	// CPUProfileENV names the directory CPU profiles are written to.
	CPUProfileENV = "CPU_PROFILE_DIR"
	// PPROFTimeENV is the profiling interval, as a duration with unit (e.g. 10s).
	PPROFTimeENV = "PPROF_TIME"
	// MemProfileRate is the name of the environment variable that sets
	// runtime.MemProfileRate; it is not the rate itself.
	MemProfileRate = "MEM_PROFILE_RATE"
)
24 |
25 | func init() {
26 | if env.GetEnvOrDefault(PPROFSwitchENV, 0) == 1 {
27 | startDefaultProfiler()
28 | }
29 | }
30 |
31 | func startDefaultProfiler() {
32 | log.Printf("[+] GOOS: %v\n", runtime.GOOS)
33 | log.Printf("[+] GOARCH: %v\n", runtime.GOARCH)
34 | log.Printf("[+] Command: %v\n", strings.Join(os.Args, " "))
35 | log.Println("Available PPROF Config Options:")
36 | log.Printf("%-16v - directory to write memory profiles to\n", MemProfileENV)
37 | log.Printf("%-16v - directory to write cpu profiles to\n", CPUProfileENV)
38 | log.Printf("%-16v - polling time for cpu and memory profiles (with unit ex: 10s)\n", PPROFTimeENV)
39 | log.Printf("%-16v - memory profiling rate (default 4096)\n", MemProfileRate)
40 |
41 | memProfilesDir := env.GetEnvOrDefault(MemProfileENV, "memdump")
42 | cpuProfilesDir := env.GetEnvOrDefault(CPUProfileENV, "cpuprofile")
43 | pprofTimeDuration := env.GetEnvOrDefault(PPROFTimeENV, time.Duration(3)*time.Second)
44 | pprofRate := env.GetEnvOrDefault(MemProfileRate, 4096)
45 |
46 | _ = os.MkdirAll(memProfilesDir, 0755)
47 | _ = os.MkdirAll(cpuProfilesDir, 0755)
48 |
49 | runtime.MemProfileRate = pprofRate
50 | log.Printf("profile: memory profiling enabled (rate %d), %s\n", runtime.MemProfileRate, memProfilesDir)
51 | log.Printf("profile: ticker enabled (rate %s)\n", pprofTimeDuration)
52 |
53 | // cpu ticker and profiler
54 | go func() {
55 | ticker := time.NewTicker(pprofTimeDuration)
56 | count := 0
57 | buff := bytes.Buffer{}
58 | log.Printf("profile: cpu profiling enabled (ticker %s)\n", pprofTimeDuration)
59 | for {
60 | err := pprof.StartCPUProfile(&buff)
61 | if err != nil {
62 | log.Fatalf("profile: could not start cpu profile: %s\n", err)
63 | }
64 | <-ticker.C
65 | pprof.StopCPUProfile()
66 | if err := os.WriteFile(filepath.Join(cpuProfilesDir, "cpuprofile-t"+strconv.Itoa(count)+".out"), buff.Bytes(), 0755); err != nil {
67 | log.Fatalf("profile: could not write cpu profile: %s\n", err)
68 | }
69 | buff.Reset()
70 | count++
71 | }
72 | }()
73 |
74 | // memory ticker and profiler
75 | go func() {
76 | ticker := time.NewTicker(pprofTimeDuration)
77 | count := 0
78 | log.Printf("profile: memory profiling enabled (ticker %s)\n", pprofTimeDuration)
79 | for {
80 | <-ticker.C
81 | var buff bytes.Buffer
82 | if err := pprof.WriteHeapProfile(&buff); err != nil {
83 | log.Printf("profile: could not write memory profile: %s\n", err)
84 | }
85 | err := os.WriteFile(filepath.ToSlash(filepath.Join(memProfilesDir, "memprofile-t"+strconv.Itoa(count)+".out")), buff.Bytes(), 0755)
86 | if err != nil {
87 | log.Printf("profile: could not write memory profile: %s\n", err)
88 | }
89 | count++
90 | }
91 | }()
92 | }
93 |
--------------------------------------------------------------------------------
/pprof/server.go:
--------------------------------------------------------------------------------
1 | package pprof
2 |
3 | import (
4 | "context"
5 | "net/http"
6 | "net/http/pprof"
7 | "runtime"
8 | "time"
9 |
10 | "github.com/felixge/fgprof"
11 | "github.com/projectdiscovery/gologger"
12 | "github.com/projectdiscovery/utils/env"
13 | )
14 |
// PPROFServerAddressENV names the environment variable holding the listen
// address of the pprof debug server (NewPprofServer defaults to 127.0.0.1:8086).
const PPROFServerAddressENV = "PPROF_SERVER_ADDRESS"
18 |
19 | type PprofServer struct {
20 | server *http.Server
21 | }
22 |
23 | func NewPprofServer() *PprofServer {
24 | address := env.GetEnvOrDefault(PPROFServerAddressENV, "127.0.0.1:8086")
25 |
26 | mux := http.NewServeMux()
27 | // Default pprof handlers
28 | mux.HandleFunc("/debug/pprof/", pprof.Index)
29 | mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
30 | mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
31 | mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
32 | mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
33 |
34 | // Also add fgprof for more detailed profiling
35 | mux.Handle("/debug/fgprof", fgprof.Handler())
36 |
37 | server := &http.Server{
38 | Addr: address,
39 | Handler: mux,
40 | }
41 | // Enable block and mutex profiling as well
42 | runtime.SetBlockProfileRate(1)
43 | runtime.SetMutexProfileFraction(1)
44 |
45 | return &PprofServer{server: server}
46 | }
47 |
48 | func (p *PprofServer) Start() {
49 | gologger.Info().Msgf("Listening pprof debug server on: %s", p.server.Addr)
50 |
51 | go func() {
52 | if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
53 | gologger.Error().Msgf("pprof server failed to start: %s", err)
54 | }
55 | }()
56 | }
57 |
58 | func (p *PprofServer) Stop() {
59 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
60 | defer cancel()
61 | _ = p.server.Shutdown(ctx)
62 | }
63 |
--------------------------------------------------------------------------------
/process/docker.go:
--------------------------------------------------------------------------------
1 | package process
2 |
3 | import (
4 | "bufio"
5 | "os"
6 | "strings"
7 |
8 | fileutil "github.com/projectdiscovery/utils/file"
9 | )
10 |
11 | // RunningInContainer checks if the process is running in a docker container
12 | // and returns true if it is.
13 | // reference: https://www.baeldung.com/linux/is-process-running-inside-container
14 | func RunningInContainer() (bool, string) {
15 | if fileutil.FileOrFolderExists("/.dockerenv") {
16 | return true, "docker"
17 | }
18 | // fallback and check using controlgroup 1 detect
19 | if !fileutil.FileExists("/proc/1/cgroup") {
20 | return false, ""
21 | }
22 | f, err := os.Open("/proc/1/cgroup")
23 | if err != nil {
24 | return false, ""
25 | }
26 | defer f.Close()
27 | buff := bufio.NewScanner(f)
28 | for buff.Scan() {
29 | if strings.Contains(buff.Text(), "/docker") {
30 | return true, "docker"
31 | }
32 | if strings.Contains(buff.Text(), "/lxc") {
33 | return true, "lxc"
34 | }
35 | }
36 | // fallback and check using controlgroup 2 detect
37 | f2, err := os.Open("/proc/self/mountinfo")
38 | if err != nil {
39 | return false, ""
40 | }
41 | defer f2.Close()
42 | buff2 := bufio.NewScanner(f2)
43 | for buff2.Scan() {
44 | if strings.Contains(buff2.Text(), "/docker") {
45 | return true, "docker"
46 | }
47 | if strings.Contains(buff2.Text(), "/lxc") {
48 | return true, "lxc"
49 | }
50 | }
51 |
52 | return false, ""
53 | }
54 |
--------------------------------------------------------------------------------
/process/process.go:
--------------------------------------------------------------------------------
1 | package process
2 |
3 | import (
4 | stringsutil "github.com/projectdiscovery/utils/strings"
5 | ps "github.com/shirou/gopsutil/v3/process"
6 | )
7 |
8 | // CloseProcesses part
9 | func CloseProcesses(predicate func(process *ps.Process) bool, skipPids map[int32]struct{}) {
10 | processes, err := ps.Processes()
11 | if err != nil {
12 | return
13 | }
14 |
15 | for _, process := range processes {
16 | // skip processes that do not satisfy the predicate
17 | if !predicate(process) {
18 | continue
19 | }
20 | // skip processes that are in the skip list
21 | if _, ok := skipPids[process.Pid]; ok {
22 | continue
23 | }
24 | _ = process.Kill()
25 | }
26 | }
27 |
28 | // FindProcesses finds chrome process running on host
29 | func FindProcesses(predicate func(process *ps.Process) bool) map[int32]struct{} {
30 | processes, _ := ps.Processes()
31 | list := make(map[int32]struct{})
32 | for _, process := range processes {
33 | if predicate(process) {
34 | list[process.Pid] = struct{}{}
35 | if ppid, err := process.Ppid(); err == nil {
36 | list[ppid] = struct{}{}
37 | }
38 | }
39 | }
40 | return list
41 | }
42 |
43 | // IsChromeProcess checks if a process is chrome/chromium
44 | func IsChromeProcess(process *ps.Process) bool {
45 | name, _ := process.Name()
46 | executable, _ := process.Exe()
47 | return stringsutil.ContainsAny(name, "chrome", "chromium") || stringsutil.ContainsAny(executable, "chrome", "chromium")
48 | }
49 |
--------------------------------------------------------------------------------
/proxy/README.md:
--------------------------------------------------------------------------------
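1 | ## proxy utils
2 | Helpers for validating proxy URLs, probing which proxies are alive, and detecting whether a proxy is Burp Suite.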
3 |
--------------------------------------------------------------------------------
/proxy/burp.go:
--------------------------------------------------------------------------------
1 | package proxyutils
2 |
3 | import (
4 | "bytes"
5 | "crypto/tls"
6 | "errors"
7 | "fmt"
8 | "io"
9 | "net/http"
10 | "net/url"
11 | )
12 |
13 | // IsBurp checks if the target proxy URL is burp suite
14 | func IsBurp(proxyURL string) (bool, error) {
15 | return getURLWithHTTPProxy("http://burpsuite/", proxyURL, func(resp *http.Response) (bool, error) {
16 | if resp.StatusCode != http.StatusOK {
17 | return false, fmt.Errorf("unexpected status code (200 wanted): %d", resp.StatusCode)
18 | }
19 |
20 | body, err := io.ReadAll(resp.Body)
21 | if err != nil {
22 | return false, err
23 | }
24 |
25 | defer resp.Body.Close()
26 |
27 | return bytes.Contains(body, []byte("Burp Suite")), nil
28 | })
29 | }
30 |
31 | // ValidateOne returns the first valid proxy from a list of proxies by setting up a test connection with scanme.sh
32 | func ValidateOne(proxies ...string) (string, error) {
33 | for _, proxy := range proxies {
34 | ok, err := getURLWithHTTPProxy("https://scanme.sh", proxy, func(resp *http.Response) (bool, error) {
35 | if resp.StatusCode != http.StatusOK {
36 | return false, fmt.Errorf("unexpected status code (200 wanted): %d", resp.StatusCode)
37 | }
38 |
39 | body, err := io.ReadAll(resp.Body)
40 | if err != nil {
41 | return false, err
42 | }
43 | defer resp.Body.Close()
44 |
45 | return len(body) > 0, nil
46 | })
47 | if ok {
48 | return proxy, err
49 | }
50 | }
51 |
52 | return "", errors.New("no valid proxy found")
53 | }
54 |
// getURLWithHTTPProxy GETs targetURL through the proxy at proxyURL and returns
// checkCallback's verdict on the response.
//
// TLS certificate verification is disabled so TLS-intercepting proxies (such as
// Burp) do not fail the probe. The response body is always closed after the
// callback returns, and the idle connections of the per-call transport are
// released. Proxy URL parse errors and request errors are returned as-is, in
// which case the callback is not invoked.
func getURLWithHTTPProxy(targetURL, proxyURL string, checkCallback func(resp *http.Response) (bool, error)) (bool, error) {
	parsedProxy, err := url.Parse(proxyURL)
	if err != nil {
		return false, err
	}
	transport := &http.Transport{
		TLSClientConfig: &tls.Config{
			InsecureSkipVerify: true,
		},
		Proxy: http.ProxyURL(parsedProxy),
	}
	// A new transport is built per call; drop its pooled connections afterwards.
	defer transport.CloseIdleConnections()
	httpClient := &http.Client{Transport: transport}

	resp, err := httpClient.Get(targetURL)
	if err != nil {
		return false, err
	}
	// Guarantees the body is closed even if the callback returns early.
	defer resp.Body.Close()

	return checkCallback(resp)
}
76 |
--------------------------------------------------------------------------------
/proxy/proxy.go:
--------------------------------------------------------------------------------
1 | package proxyutils
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net"
7 | "net/url"
8 | "strings"
9 | "time"
10 |
11 | errorutil "github.com/projectdiscovery/utils/errors"
12 | "github.com/remeh/sizedwaitgroup"
13 | )
14 |
// proxyResult is the outcome of checking one proxy: either the URL of a proxy
// that accepted a connection, or the error explaining why it could not be used.
type proxyResult struct {
	AliveProxy string // proxy URL, set when the proxy is reachable
	Error      error  // failure reason, set otherwise
}
19 |
// ProxyProbeConcurrency caps how many proxies GetAnyAliveProxy probes in
// parallel (default 8).
var ProxyProbeConcurrency int = 8
22 |
// Proxy URL schemes accepted by GetProxyURL.
const (
	HTTP   = "http"
	HTTPS  = "https"
	SOCKS5 = "socks5"
)
28 |
29 | // GetAnyAliveProxy takes proxies as input and returns the first alive proxy
30 | // or returns error if all of them not alive
31 | func GetAnyAliveProxy(timeoutInSec int, proxies ...string) (string, error) {
32 | sg := sizedwaitgroup.New(ProxyProbeConcurrency)
33 | resChan := make(chan proxyResult, 4)
34 | ctx, cancel := context.WithCancel(context.Background())
35 |
36 | go func() {
37 | for _, v := range proxies {
38 | // skip iterating if alive proxy is found
39 | select {
40 | case <-ctx.Done():
41 | return
42 | default:
43 | proxy, err := GetProxyURL(v)
44 | if err != nil {
45 | resChan <- proxyResult{Error: err}
46 | continue
47 | }
48 | sg.Add()
49 | go func(proxyAddr url.URL) {
50 | defer sg.Done()
51 | select {
52 | case <-ctx.Done():
53 | return
54 | case resChan <- testProxyConn(proxyAddr, timeoutInSec):
55 | cancel()
56 | }
57 | }(proxy)
58 | }
59 | }
60 | sg.Wait()
61 | close(resChan)
62 | }()
63 |
64 | errstack := []string{}
65 | for {
66 | result, ok := <-resChan
67 | if !ok {
68 | break
69 | }
70 | if result.AliveProxy != "" {
71 | // found alive proxy return now
72 | return result.AliveProxy, nil
73 | } else if result.Error != nil {
74 | errstack = append(errstack, result.Error.Error())
75 | }
76 | }
77 |
78 | // all proxies are dead
79 | return "", errorutil.NewWithTag("proxyutils", "all proxies are dead got : %v", strings.Join(errstack, " : "))
80 | }
81 |
82 | // dial and test if proxy is open
83 | func testProxyConn(proxyAddr url.URL, timeoutInSec int) proxyResult {
84 | p := proxyResult{}
85 | if Conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%s", proxyAddr.Hostname(), proxyAddr.Port()), time.Duration(timeoutInSec)*time.Second); err == nil {
86 | _ = Conn.Close()
87 | p.AliveProxy = proxyAddr.String()
88 | } else {
89 | p.Error = err
90 | }
91 | return p
92 | }
93 |
94 | // GetProxyURL returns a Proxy URL after validating if given proxy url is valid
95 | func GetProxyURL(proxyAddr string) (url.URL, error) {
96 | if url, err := url.Parse(proxyAddr); err == nil && isSupportedProtocol(url.Scheme) {
97 | return *url, nil
98 | }
99 | return url.URL{}, errorutil.New("invalid proxy format (It should be http[s]/socks5://[username:password@]host:port)").WithTag("proxyutils")
100 | }
101 |
102 | // isSupportedProtocol checks given protocols are supported
103 | func isSupportedProtocol(value string) bool {
104 | return value == HTTP || value == HTTPS || value == SOCKS5
105 | }
106 |
--------------------------------------------------------------------------------
/proxy/proxy_test.go:
--------------------------------------------------------------------------------
1 | //go:build proxy
2 |
3 | package proxyutils
4 |
5 | // package tests will be executed only with (running proxy is necessary):
6 | // go test -tags proxy
7 |
8 | import (
9 | "testing"
10 |
11 | "github.com/stretchr/testify/require"
12 | )
13 |
14 | const burpURL = "http://127.0.0.1:8080"
15 |
16 | // a local instance of burp community is necessary
17 | func TestIsBurp(t *testing.T) {
18 | ok, err := IsBurp(burpURL)
19 | require.Nil(t, err)
20 | require.True(t, ok)
21 | }
22 |
23 | // a valid proxy is necessary
24 | func TestValidateOne(t *testing.T) {
25 | proxyURL, err := ValidateOne(burpURL)
26 | require.Nil(t, err)
27 | require.Equal(t, burpURL, proxyURL)
28 | }
29 |
--------------------------------------------------------------------------------
/ptr/ptr.go:
--------------------------------------------------------------------------------
1 | package ptr
2 |
// Safe dereferences v without panicking:
//   - a nil v yields the zero value of T
//   - a non-nil v yields *v
//
// Example:
//
//	var v *int
//	var x = ptr.Safe(v)
func Safe[T any](v *T) T {
	var zero T
	if v != nil {
		return *v
	}
	return zero
}
17 |
// Of returns a pointer to a fresh copy of v.
//
// Example:
//
//	var v int
//	var p = ptr.Of(v)
func Of[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
27 |
// When returns a pointer to a copy of v if condition holds, and nil otherwise.
//   - condition false => nil
//   - condition true  => pointer to v
//
// Example:
//
//	var v bool
//	var p = ptr.When(v, v != false)
func When[T any](v T, condition bool) *T {
	if condition {
		return &v
	}
	return nil
}
42 |
--------------------------------------------------------------------------------
/ptr/ptr_test.go:
--------------------------------------------------------------------------------
1 | package ptr
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/require"
7 | )
8 |
9 | func TestSafe(t *testing.T) {
10 | type args[T any] struct {
11 | v *T
12 | }
13 | type testCase[T any] struct {
14 | name string
15 | args args[T]
16 | want T
17 | }
18 | tests := []testCase[int]{
19 | {
20 | name: "struct=>int - NilPointer",
21 | args: args[int]{v: nil},
22 | want: 0,
23 | },
24 | {
25 | name: "struct=>int - NonNilPointer",
26 | args: args[int]{v: new(int)},
27 | want: 0,
28 | },
29 | }
30 |
31 | for _, tt := range tests {
32 | t.Run(tt.name, func(t *testing.T) {
33 | got := Safe(tt.args.v)
34 | require.Equal(t, tt.want, got, "Safe() = %v, want %v", got, tt.want)
35 | })
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/race/README.md:
--------------------------------------------------------------------------------
1 | # raceutil
The package reports whether the Go race detector is enabled in the current build.
3 |
--------------------------------------------------------------------------------
/race/norace.go:
--------------------------------------------------------------------------------
1 | //go:build !race
2 |
3 | // Package raceutil reports if the Go race detector is enabled.
4 | package raceutil
5 |
// Enabled reports whether the race detector is enabled. This file is compiled
// only when the race build constraint is not satisfied (//go:build !race), so
// the value here is always false.
const Enabled = false
8 |
--------------------------------------------------------------------------------
/race/race.go:
--------------------------------------------------------------------------------
1 | //go:build race
2 |
3 | // Package raceutil reports if the Go race detector is enabled.
4 | package raceutil
5 |
// Enabled reports whether the race detector is enabled. This file is compiled
// only when the race build constraint is satisfied (//go:build race, e.g. when
// building with -race), so the value here is always true.
const Enabled = true
8 |
--------------------------------------------------------------------------------
/rand/number.go:
--------------------------------------------------------------------------------
1 | package rand
2 |
3 | import (
4 | "crypto/rand"
5 | "errors"
6 | "math/big"
7 | crand "math/rand"
8 | )
9 |
// IntN returns a uniform random value in [0, max). It errors if max <= 0.
// The value comes from crypto/rand. If that source fails, IntN falls back to
// the non-cryptographic math/rand generator instead of returning an error.
func IntN(max int) (int, error) {
	if max > 0 {
		upper := big.NewInt(int64(max))
		if nBig, err := rand.Int(rand.Reader, upper); err == nil {
			return int(nBig.Int64()), nil
		}
		return crand.Intn(max), nil
	}
	return 0, errors.New("max can't be <= 0")
}
21 |
--------------------------------------------------------------------------------
/rand/number_test.go:
--------------------------------------------------------------------------------
1 | package rand
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/require"
7 | )
8 |
9 | func TestIntN(t *testing.T) {
10 | type testCase struct {
11 | input int
12 | expectedOk bool
13 | }
14 |
15 | testCases := []testCase{
16 | {input: 10, expectedOk: true},
17 | {input: 0, expectedOk: false},
18 | {input: -10, expectedOk: false},
19 | }
20 |
21 | for _, tc := range testCases {
22 | i, err := IntN(tc.input)
23 | ok := i >= 0 && i <= tc.input && err == nil
24 | require.Equal(t, tc.expectedOk, ok)
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/reader/conn_read_test.go:
--------------------------------------------------------------------------------
1 | package reader
2 |
3 | import (
4 | "bytes"
5 | "crypto/tls"
6 | "strings"
7 | "testing"
8 | "time"
9 |
10 | "github.com/stretchr/testify/require"
11 | )
12 |
13 | func TestConnReadN(t *testing.T) {
14 | timeout := time.Duration(5) * time.Second
15 |
16 | t.Run("Test with N as -1", func(t *testing.T) {
17 | reader := strings.NewReader("Hello, World!")
18 | data, err := ConnReadNWithTimeout(reader, -1, timeout)
19 | if err != nil {
20 | t.Errorf("Unexpected error: %v", err)
21 | }
22 | if string(data) != "Hello, World!" {
23 | t.Errorf("Expected 'Hello, World!', got '%s'", string(data))
24 | }
25 | })
26 |
27 | t.Run("Test with N as 0", func(t *testing.T) {
28 | reader := strings.NewReader("Hello, World!")
29 | data, err := ConnReadNWithTimeout(reader, 0, timeout)
30 | if err != nil {
31 | t.Errorf("Unexpected error: %v", err)
32 | }
33 | if len(data) != 0 {
34 | t.Errorf("Expected empty, got '%s'", string(data))
35 | }
36 | })
37 |
38 | t.Run("Test with N greater than MaxReadSize", func(t *testing.T) {
39 | reader := bytes.NewReader(make([]byte, MaxReadSize+1))
40 | _, err := ConnReadNWithTimeout(reader, MaxReadSize+1, timeout)
41 | if err != ErrTooLarge {
42 | t.Errorf("Expected 'ErrTooLarge', got '%v'", err)
43 | }
44 | })
45 |
46 | t.Run("Test with N less than MaxReadSize", func(t *testing.T) {
47 | reader := strings.NewReader("Hello, World!")
48 | data, err := ConnReadNWithTimeout(reader, 5, timeout)
49 | if err != nil {
50 | t.Errorf("Unexpected error: %v", err)
51 | }
52 | if string(data) != "Hello" {
53 | t.Errorf("Expected 'Hello', got '%s'", string(data))
54 | }
55 | })
56 | t.Run("Read From Connection", func(t *testing.T) {
57 | conn, err := tls.Dial("tcp", "projectdiscovery.io:443", &tls.Config{InsecureSkipVerify: true})
58 | _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
59 | require.Nil(t, err, "could not connect to projectdiscovery.io over tls")
60 | defer conn.Close()
61 | _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: projectdiscovery.io\r\nConnection: close\r\n\r\n"))
62 | require.Nil(t, err, "could not write to connection")
63 | data, err := ConnReadNWithTimeout(conn, -1, timeout)
64 | require.Nilf(t, err, "could not read from connection: %s", err)
65 | require.NotEmpty(t, data, "could not read from connection")
66 | })
67 |
68 | t.Run("Read From Connection which times out", func(t *testing.T) {
69 | conn, err := tls.Dial("tcp", "projectdiscovery.io:443", &tls.Config{InsecureSkipVerify: true})
70 | _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
71 | require.Nil(t, err, "could not connect to projectdiscovery.io over tls")
72 | defer conn.Close()
73 | _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: projectdiscovery.io\r\n\r\n"))
74 | require.Nil(t, err, "could not write to connection")
75 | data, err := ConnReadNWithTimeout(conn, -1, timeout)
76 | require.Nilf(t, err, "could not read from connection: %s", err)
77 | require.NotEmpty(t, data, "could not read from connection")
78 | })
79 | }
80 |
--------------------------------------------------------------------------------
/reader/error.go:
--------------------------------------------------------------------------------
1 | package reader
2 |
3 | import "errors"
4 |
// ErrTimeout is returned by KeyPressReader.Read and TimeoutReader.Read when
// their configured Timeout elapses before any data is available.
var ErrTimeout = errors.New("Timeout")
6 |
--------------------------------------------------------------------------------
/reader/examples/keypress/buffered/keypress.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "log"
5 | "sync"
6 | "time"
7 |
8 | "github.com/projectdiscovery/utils/reader"
9 | stringsutil "github.com/projectdiscovery/utils/strings"
10 | )
11 |
12 | func main() {
13 | stdr := reader.KeyPressReader{
14 | Timeout: time.Duration(5 * time.Second),
15 | Once: &sync.Once{},
16 | }
17 |
18 | stdr.Start()
19 | defer stdr.Stop()
20 |
21 | for {
22 | data := make([]byte, stdr.BufferSize)
23 | n, err := stdr.Read(data)
24 | log.Println(n, err)
25 |
26 | if stringsutil.IsCTRLC(string(data)) {
27 | break
28 | }
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/reader/examples/keypress/raw/keypress.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "log"
5 | "sync"
6 | "time"
7 |
8 | "github.com/projectdiscovery/utils/reader"
9 | stringsutil "github.com/projectdiscovery/utils/strings"
10 | )
11 |
12 | func main() {
13 | stdr := reader.KeyPressReader{
14 | Timeout: time.Duration(5 * time.Second),
15 | Once: &sync.Once{},
16 | Raw: true,
17 | }
18 |
19 | stdr.Start()
20 | defer stdr.Stop()
21 |
22 | for {
23 | data := make([]byte, 1)
24 | n, err := stdr.Read(data)
25 | if stringsutil.IsPrintable(string(data)) {
26 | log.Println(n, err)
27 | }
28 |
29 | if stringsutil.IsCTRLC(string(data)) {
30 | break
31 | }
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/reader/frozen_reader.go:
--------------------------------------------------------------------------------
1 | package reader
2 |
3 | import (
4 | "io"
5 | "math"
6 | "time"
7 | )
8 |
// FrozenReader is an io.Reader whose Read blocks for math.MaxInt32 seconds
// (about 68 years) before reporting io.EOF, so in practice it never returns.
// It is handy for exercising timeout handling.
type FrozenReader struct{}

// Read sleeps for math.MaxInt32 seconds, then reports io.EOF without touching p.
func (reader FrozenReader) Read(p []byte) (n int, err error) {
	const forever = time.Duration(math.MaxInt32) * time.Second
	time.Sleep(forever)
	n, err = 0, io.EOF
	return n, err
}
17 |
--------------------------------------------------------------------------------
/reader/frozen_reader_test.go:
--------------------------------------------------------------------------------
1 | package reader
2 |
3 | import (
4 | "io"
5 | "os"
6 | "testing"
7 | "time"
8 | )
9 |
10 | func TestFrozenReader(t *testing.T) {
11 | forever := func() {
12 | wrappedStdin := FrozenReader{}
13 | _, err := io.Copy(os.Stdout, wrappedStdin)
14 | if err != nil {
15 | return
16 | }
17 | }
18 | go forever()
19 | <-time.After(10 * time.Second)
20 | }
21 |
--------------------------------------------------------------------------------
/reader/rawmode/raw_mode.go:
--------------------------------------------------------------------------------
1 | package rawmode
2 |
3 | import (
4 | "os"
5 | )
6 |
// These hooks are assigned by the platform-specific init functions in this
// package (raw_mode_posix.go for darwin/linux, raw_mode_windows.go for
// windows); presumably they stay nil on other platforms, so confirm before
// calling them there.
var (
	// GetMode from file descriptor
	GetMode func(std *os.File) (interface{}, error)
	// SetMode to file descriptor
	SetMode func(std *os.File, mode interface{}) error
	// SetRawMode to file descriptor enriching existing mode with raw console flags
	SetRawMode func(std *os.File, mode interface{}) error
	// Read from file descriptor to buffer
	Read func(std *os.File, buf []byte) (int, error)

	// TCSETS and TCGETS hold the ioctl request codes used to set and get
	// terminal attributes; they are assigned per platform in values_*.go.
	TCSETS uintptr
	TCGETS uintptr
)
20 |
--------------------------------------------------------------------------------
/reader/rawmode/raw_mode_posix.go:
--------------------------------------------------------------------------------
1 | //go:build darwin || linux
2 |
3 | package rawmode
4 |
5 | import (
6 | "errors"
7 | "os"
8 | "syscall"
9 | "unsafe"
10 | )
11 |
12 | func init() {
13 | GetMode = func(std *os.File) (interface{}, error) {
14 | return getMode(std)
15 | }
16 |
17 | SetMode = func(std *os.File, mode interface{}) error {
18 | m, ok := mode.(*syscall.Termios)
19 | if !ok {
20 | return errors.New("invalid syscall.Termios")
21 | }
22 | return setMode(std, m)
23 | }
24 |
25 | SetRawMode = func(std *os.File, mode interface{}) error {
26 | m, ok := mode.(*syscall.Termios)
27 | if !ok {
28 | return errors.New("invalid syscall.Termios")
29 | }
30 | return setRawMode(std, m)
31 | }
32 |
33 | Read = func(std *os.File, buf []byte) (int, error) {
34 | return read(std, buf)
35 | }
36 | }
37 |
38 | func getTermios(fd uintptr) (*syscall.Termios, error) {
39 | var t syscall.Termios
40 | _, _, err := syscall.Syscall6(
41 | syscall.SYS_IOCTL,
42 | os.Stdin.Fd(),
43 | TCGETS,
44 | uintptr(unsafe.Pointer(&t)),
45 | 0, 0, 0)
46 |
47 | return &t, err
48 | }
49 |
50 | func setTermios(fd uintptr, term *syscall.Termios) error {
51 | _, _, err := syscall.Syscall6(
52 | syscall.SYS_IOCTL,
53 | os.Stdin.Fd(),
54 | TCSETS,
55 | uintptr(unsafe.Pointer(term)),
56 | 0, 0, 0)
57 | return err
58 | }
59 |
// setRaw rewrites term in place into raw mode, mirroring what cfmakeraw(3)
// documents. Input translation, echo, canonical (line) mode and signal keys
// are disabled, characters become 8-bit without parity, and reads return
// after a single byte with no inter-byte timer.
func setRaw(term *syscall.Termios) {
	const (
		inputOff = syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP |
			syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
		localOff   = syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
		controlOff = syscall.CSIZE | syscall.PARENB
	)

	term.Iflag &^= inputOff
	term.Lflag &^= localOff
	term.Cflag &^= controlOff
	term.Cflag |= syscall.CS8

	term.Cc[syscall.VTIME] = 0
	term.Cc[syscall.VMIN] = 1
}
71 |
72 | func getMode(std *os.File) (*syscall.Termios, error) {
73 | return getTermios(os.Stdin.Fd())
74 | }
75 |
76 | func setMode(std *os.File, mode *syscall.Termios) error {
77 | return setTermios(os.Stdin.Fd(), mode)
78 | }
79 |
80 | func setRawMode(std *os.File, mode *syscall.Termios) error {
81 | setRaw(mode)
82 | return SetMode(std, mode)
83 | }
84 |
85 | func read(std *os.File, buf []byte) (int, error) {
86 | return syscall.Read(int(os.Stdin.Fd()), buf)
87 | }
88 |
--------------------------------------------------------------------------------
/reader/rawmode/raw_mode_windows.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package rawmode
4 |
5 | import (
6 | "errors"
7 | "os"
8 | "syscall"
9 | "unsafe"
10 | )
11 |
12 | var (
13 | // load kernel32 lib
14 | kernel32 = syscall.NewLazyDLL("kernel32.dll")
15 |
16 | // get handlers to console API
17 | procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
18 | procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
19 | )
20 |
21 | const (
22 | enableLineInput = 2
23 | enableEchoInput = 4
24 | enableProcessedInput = 1
25 | enableWindowInput = 8 //nolint
26 | enableMouseInput = 16 //nolint
27 | enableInsertMode = 32 //nolint
28 | enableQuickEditMode = 64 //nolint
29 | enableExtendedFlags = 128 //nolint
30 | enableAutoPosition = 256 //nolint
31 | enableProcessedOutput = 1 //nolint
32 | enableWrapAtEolOutput = 2 //nolint
33 | )
34 |
35 | func init() {
36 | GetMode = func(std *os.File) (interface{}, error) {
37 | return getMode(std)
38 | }
39 |
40 | SetMode = func(std *os.File, mode interface{}) error {
41 | m, ok := mode.(uint32)
42 | if !ok {
43 | return errors.New("invalid syscall.Termios")
44 | }
45 | return setMode(std, m)
46 | }
47 |
48 | SetRawMode = func(std *os.File, mode interface{}) error {
49 | m, ok := mode.(uint32)
50 | if !ok {
51 | return errors.New("invalid syscall.Termios")
52 | }
53 | return setRawMode(std, m)
54 | }
55 |
56 | Read = func(std *os.File, buf []byte) (int, error) {
57 | return read(std, buf)
58 | }
59 | }
60 |
61 | func getTermMode(fd uintptr) (uint32, error) {
62 | var mode uint32
63 | _, _, err := syscall.SyscallN(
64 | procGetConsoleMode.Addr(),
65 | fd,
66 | uintptr(unsafe.Pointer(&mode)),
67 | 0)
68 | if err != 0 {
69 | return mode, err
70 | }
71 | return mode, nil
72 | }
73 |
74 | func setTermMode(fd uintptr, mode uint32) error {
75 | _, _, err := syscall.SyscallN(
76 | procSetConsoleMode.Addr(),
77 | fd,
78 | uintptr(mode),
79 | 0)
80 | if err != 0 {
81 | return err
82 | }
83 | return nil
84 | }
85 |
86 | // GetMode from file descriptor
87 | func getMode(std *os.File) (uint32, error) {
88 | return getTermMode(os.Stdin.Fd())
89 | }
90 |
91 | // SetMode to file descriptor
92 | func setMode(std *os.File, mode uint32) error {
93 | return setTermMode(os.Stdin.Fd(), mode)
94 | }
95 |
96 | // SetRawMode to file descriptor enriching existign mode with raw console flags
97 | func setRawMode(std *os.File, mode uint32) error {
98 | mode &^= (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput)
99 | return SetMode(std, mode)
100 | }
101 |
102 | // Read from file descriptor to buffer
103 | func read(std *os.File, buf []byte) (int, error) {
104 | return syscall.Read(syscall.Handle(os.Stdin.Fd()), buf)
105 | }
106 |
--------------------------------------------------------------------------------
/reader/rawmode/values_darwin.go:
--------------------------------------------------------------------------------
1 | //go:build darwin
2 |
3 | package rawmode
4 |
5 | import "syscall"
6 |
7 | func init() {
8 | TCSETS = syscall.TIOCGETA
9 | TCGETS = syscall.TIOCSETA
10 | }
11 |
--------------------------------------------------------------------------------
/reader/rawmode/values_linux.go:
--------------------------------------------------------------------------------
1 | //go:build linux
2 |
3 | package rawmode
4 |
5 | import "syscall"
6 |
7 | func init() {
8 | TCSETS = syscall.TCGETS
9 | TCGETS = syscall.TCSETS
10 | }
11 |
--------------------------------------------------------------------------------
/reader/reader_keypress.go:
--------------------------------------------------------------------------------
1 | package reader
2 |
3 | import (
4 | "context"
5 | "os"
6 | "sync"
7 | "time"
8 |
9 | "github.com/projectdiscovery/utils/reader/rawmode"
10 | )
11 |
12 | type KeyPressReader struct {
13 | originalMode interface{}
14 | Timeout time.Duration
15 | datachan chan []byte
16 | Once *sync.Once
17 | Raw bool
18 | BufferSize int
19 | }
20 |
21 | func (reader *KeyPressReader) Start() error {
22 | reader.Once.Do(func() {
23 | go reader.read()
24 | reader.originalMode, _ = rawmode.GetMode(os.Stdin)
25 | if reader.Raw {
26 | reader.BufferSize = 1
27 | } else {
28 | reader.BufferSize = 512
29 | }
30 | })
31 | // set raw mode
32 | if reader.Raw {
33 | mode, _ := rawmode.GetMode(os.Stdin)
34 | return rawmode.SetRawMode(os.Stdin, mode)
35 | }
36 |
37 | // proceed with buffered input - only new lines are detected
38 | return nil
39 | }
40 |
41 | func (reader *KeyPressReader) Stop() error {
42 | // disable raw mode
43 | if reader.Raw {
44 | return rawmode.SetMode(os.Stdin, reader.originalMode)
45 | }
46 |
47 | // nop
48 | return nil
49 | }
50 |
51 | func (reader *KeyPressReader) read() {
52 | if reader.datachan == nil {
53 | reader.datachan = make(chan []byte)
54 | }
55 |
56 | for {
57 | var (
58 | n int
59 | err error
60 | r = make([]byte, reader.BufferSize)
61 | )
62 |
63 | if reader.Raw {
64 | n, err = rawmode.Read(os.Stdin, r)
65 | } else {
66 | n, err = os.Stdin.Read(r)
67 | }
68 | if n > 0 && err == nil {
69 | reader.datachan <- r
70 | }
71 | }
72 | }
73 |
74 | // Read into the buffer
75 | func (reader KeyPressReader) Read(p []byte) (n int, err error) {
76 | var (
77 | ctx context.Context
78 | cancel context.CancelFunc
79 | )
80 | if reader.Timeout > 0 {
81 | ctx, cancel = context.WithTimeout(context.Background(), time.Duration(reader.Timeout))
82 | defer cancel()
83 | }
84 |
85 | select {
86 | case <-ctx.Done():
87 | err = ErrTimeout
88 | return
89 | case data := <-reader.datachan:
90 | n = copy(p, data)
91 | return
92 | }
93 | }
94 |
--------------------------------------------------------------------------------
/reader/reusable_read_closer.go:
--------------------------------------------------------------------------------
1 | package reader
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "fmt"
7 | "io"
8 | "strings"
9 | "sync"
10 | )
11 |
// ReusableReadCloser is an io.ReadCloser that can be read to EOF repeatedly.
// Every byte consumed is mirrored into a backup buffer. On EOF the backup is
// copied back, so the next read starts from the beginning again. Close does
// nothing.
type ReusableReadCloser struct {
	*sync.RWMutex
	io.Reader
	readBuf *bytes.Buffer
	backBuf *bytes.Buffer
}

// NewReusableReadCloser builds a ReusableReadCloser from raw. Accepted values:
//   - nil
//   - []byte, *[]byte or string
//   - *bytes.Buffer
//   - any io.Reader, which is drained eagerly
//
// Any other type yields an error.
func NewReusableReadCloser(raw interface{}) (*ReusableReadCloser, error) {
	var readBuf, backBuf bytes.Buffer

	// drain copies the remaining contents of src into readBuf.
	drain := func(src io.Reader) error {
		_, err := readBuf.ReadFrom(src)
		return err
	}

	var err error
	switch body := raw.(type) {
	case nil:
		// no input: start with an empty buffer
	case []byte:
		readBuf = *bytes.NewBuffer(body)
	case *[]byte:
		readBuf = *bytes.NewBuffer(*body)
	case string:
		readBuf = *bytes.NewBufferString(body)
	case *bytes.Buffer:
		readBuf = *body
	case *bytes.Reader:
		err = drain(body)
	case *strings.Reader:
		err = drain(body)
	case io.ReadSeeker:
		err = drain(body)
	case io.Reader:
		err = drain(body)
	default:
		return nil, fmt.Errorf("cannot handle type %T", body)
	}
	if err != nil {
		return nil, err
	}

	return &ReusableReadCloser{
		RWMutex: &sync.RWMutex{},
		Reader:  io.TeeReader(&readBuf, &backBuf),
		readBuf: &readBuf,
		backBuf: &backBuf,
	}, nil
}

// Read reads from the buffered content under the lock and rewinds once EOF is
// reached, so the following Read starts over.
func (r ReusableReadCloser) Read(p []byte) (int, error) {
	r.Lock()
	defer r.Unlock()

	n, err := r.Reader.Read(p)
	if !errors.Is(err, io.EOF) {
		return n, err
	}
	r.reset()
	return n, err
}

// reset moves everything consumed since the last rewind (held in backBuf)
// back into readBuf.
func (r ReusableReadCloser) reset() {
	_, _ = io.Copy(r.readBuf, r.backBuf)
}

// Close is a no-op; the content stays readable.
func (r ReusableReadCloser) Close() error {
	return nil
}
103 |
--------------------------------------------------------------------------------
/reader/reusable_read_closer_test.go:
--------------------------------------------------------------------------------
1 | package reader
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "strings"
7 | "sync"
8 | "testing"
9 |
10 | "github.com/stretchr/testify/require"
11 | )
12 |
13 | func TestReusableReader(t *testing.T) {
14 | testcases := []interface{}{
15 | strings.NewReader("test"),
16 | bytes.NewBuffer([]byte("test")),
17 | bytes.NewBufferString("test"),
18 | bytes.NewReader([]byte("test")),
19 | []byte("test"),
20 | "test",
21 | }
22 | for _, v := range testcases {
23 | t.Run("sequential reuse", func(t *testing.T) {
24 | reusableReader, err := NewReusableReadCloser(v)
25 | require.Nil(t, err)
26 |
27 | for i := 0; i < 100; i++ {
28 | n, err := io.Copy(io.Discard, reusableReader)
29 | require.Nil(t, err)
30 | require.Positive(t, n)
31 |
32 | bin, err := io.ReadAll(reusableReader)
33 | require.Nil(t, err)
34 | require.Len(t, bin, 4)
35 | }
36 | })
37 |
38 | // todo: readers shouldn't be used concurrently, so here we just try to catch pontential read concurring with resets panics
39 | t.Run("concurrent-reset-with-read", func(t *testing.T) {
40 | reusableReader, err := NewReusableReadCloser(v)
41 | require.Nil(t, err)
42 |
43 | var wg sync.WaitGroup
44 |
45 | for i := 0; i < 100; i++ {
46 | wg.Add(1)
47 | go func() {
48 | defer wg.Done()
49 |
50 | _, err := io.Copy(io.Discard, reusableReader)
51 | require.Nil(t, err)
52 | _, err = io.ReadAll(reusableReader)
53 | require.Nil(t, err)
54 | }()
55 | }
56 |
57 | wg.Wait()
58 | })
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/reader/timeout_reader.go:
--------------------------------------------------------------------------------
1 | package reader
2 |
3 | import (
4 | "context"
5 | "io"
6 | "time"
7 | )
8 |
9 | // TimeoutReader is a reader wrapper that stops waiting after Timeout
10 | type TimeoutReader struct {
11 | Timeout time.Duration
12 | Reader io.Reader
13 | datachan chan struct{}
14 | }
15 |
16 | // Read into the buffer
17 | func (reader TimeoutReader) Read(p []byte) (n int, err error) {
18 | var (
19 | ctx context.Context
20 | cancel context.CancelFunc
21 | )
22 | if reader.Timeout > 0 {
23 | ctx, cancel = context.WithTimeout(context.Background(), time.Duration(reader.Timeout))
24 | defer cancel()
25 | }
26 |
27 | if reader.datachan == nil {
28 | reader.datachan = make(chan struct{})
29 | }
30 |
31 | go func() {
32 | n, err = reader.Reader.Read(p)
33 | reader.datachan <- struct{}{}
34 | }()
35 |
36 | select {
37 | case <-ctx.Done():
38 | err = ErrTimeout
39 | return
40 | case <-reader.datachan:
41 | return
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/reader/timeout_reader_test.go:
--------------------------------------------------------------------------------
1 | package reader
2 |
3 | import (
4 | "io"
5 | "os"
6 | "testing"
7 | "time"
8 |
9 | "github.com/stretchr/testify/require"
10 | )
11 |
12 | func TestTimeoutReader(t *testing.T) {
13 | wrappedStdin := TimeoutReader{
14 | Reader: FrozenReader{},
15 | Timeout: time.Duration(2 * time.Second),
16 | }
17 | _, err := io.Copy(os.Stdout, wrappedStdin)
18 | require.NotNil(t, err)
19 | }
20 |
--------------------------------------------------------------------------------
/reflect/README.md:
--------------------------------------------------------------------------------
1 | # reflectutil
2 | The package contains various helpers for reflection
--------------------------------------------------------------------------------
/reflect/tests/tests.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
// Test is a fixture type with a single unexported field. Presumably it is
// used by the reflectutil tests to exercise access to unexported fields —
// confirm against the callers.
type Test struct {
	unexported string //nolint
}
6 |
--------------------------------------------------------------------------------
/routing/router_windows.go:
--------------------------------------------------------------------------------
1 | package routing
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "net"
7 | "os/exec"
8 | "strconv"
9 | "strings"
10 |
11 | "github.com/asaskevich/govalidator"
12 | "github.com/pkg/errors"
13 | stringsutil "github.com/projectdiscovery/utils/strings"
14 | )
15 |
16 | // New creates a routing engine for windows
17 | func New() (Router, error) {
18 | var routes []*Route
19 |
20 | for _, iptype := range []RouteType{IPv4, IPv6} {
21 | netshCmd := exec.Command("netsh", "interface", iptype.String(), "show", "route")
22 | netshOutput, err := netshCmd.Output()
23 | if err != nil {
24 | return nil, err
25 | }
26 |
27 | scanner := bufio.NewScanner(bytes.NewReader(netshOutput))
28 | for scanner.Scan() {
29 | outputLine := strings.TrimSpace(scanner.Text())
30 | if outputLine == "" {
31 | continue
32 | }
33 |
34 | parts := stringsutil.SplitAny(outputLine, " \t")
35 | if len(parts) >= 6 && govalidator.IsNumeric(parts[4]) {
36 | prefix := parts[3]
37 | _, _, err := net.ParseCIDR(prefix)
38 | if err != nil {
39 | return nil, err
40 | }
41 | gateway := parts[5]
42 | interfaceIndex, err := strconv.Atoi(parts[4])
43 | if err != nil {
44 | return nil, err
45 | }
46 |
47 | networkInterface, err := net.InterfaceByIndex(interfaceIndex)
48 | if err != nil {
49 | return nil, err
50 | }
51 | isDefault := stringsutil.EqualFoldAny(prefix, "0.0.0.0/0", "::/0")
52 |
53 | route := &Route{
54 | Type: iptype,
55 | Default: isDefault,
56 | Destination: prefix,
57 | Gateway: gateway,
58 | NetworkInterface: networkInterface,
59 | }
60 |
61 | routes = append(routes, route)
62 | }
63 | }
64 | }
65 |
66 | return &RouterWindows{Routes: routes}, nil
67 | }
68 |
69 | type RouterWindows struct {
70 | Routes []*Route
71 | }
72 |
73 | func (r *RouterWindows) Route(dst net.IP) (iface *net.Interface, gateway, preferredSrc net.IP, err error) {
74 | route, err := FindRouteForIp(dst, r.Routes)
75 | if err != nil {
76 | return nil, nil, nil, errors.Wrap(err, "could not find route")
77 | }
78 |
79 | if route.NetworkInterface == nil {
80 | return nil, nil, nil, errors.Wrap(err, "could not find network interface")
81 | }
82 | ip, err := FindSourceIpForIp(route, dst)
83 | if err != nil {
84 | return nil, nil, nil, errors.Wrap(err, "could not find source ip")
85 | }
86 |
87 | return route.NetworkInterface, net.IP(route.Gateway), ip, nil
88 | }
89 |
90 | func (r *RouterWindows) RouteWithSrc(input net.HardwareAddr, src, dst net.IP) (iface *net.Interface, gateway, preferredSrc net.IP, err error) {
91 | route, err := FindRouteWithHwAndIp(input, src, r.Routes)
92 | if err != nil {
93 | return nil, nil, nil, err
94 | }
95 |
96 | return route.NetworkInterface, net.IP(route.Gateway), src, nil
97 | }
98 |
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # scripts
2 | The package contains various scripts
3 |
4 |
5 | ## versionbump
6 | This Go script can automatically bump the semantic version number defined in a Go source file. It parses the specified Go source file with `go/ast`, finds the given variable (which is assumed to contain a semantic version string), increments the specified part of the version number (major, minor, or patch) with `github.com/Masterminds/semver/v3`, and rewrites the file with the updated version.
7 |
8 | ```
9 | go run versionbump.go -file /path/to/your/file.go -var YourVersionVariable
10 | ```
11 |
12 | By default, the patch version is incremented. To increment the major or minor versions instead, specify -part major or -part minor respectively:
13 |
14 | ```
15 | go run versionbump.go -file /path/to/your/file.go -var YourVersionVariable -part minor
16 | ```
17 |
--------------------------------------------------------------------------------
/scripts/versionbump/versionbump.go:
--------------------------------------------------------------------------------
1 | package main
2 |
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"path/filepath"
	"strconv"

	semver "github.com/Masterminds/semver/v3"
)
16 |
17 | func bumpVersion(fileName, varName, part string) (string, string, error) {
18 | absPath, err := filepath.Abs(fileName)
19 | if err != nil {
20 | return "", "", fmt.Errorf("unable to get absolute path: %v", err)
21 | }
22 |
23 | fset := token.NewFileSet()
24 | node, err := parser.ParseFile(fset, absPath, nil, parser.ParseComments)
25 | if err != nil {
26 | return "", "", fmt.Errorf("could not parse file: %w", err)
27 | }
28 |
29 | var oldVersion, newVersion string
30 |
31 | ast.Inspect(node, func(n ast.Node) bool {
32 | if v, ok := n.(*ast.GenDecl); ok {
33 | for _, spec := range v.Specs {
34 | if s, ok := spec.(*ast.ValueSpec); ok {
35 | for idx, id := range s.Names {
36 | if id.Name == varName {
37 | oldVersion, _ = strconv.Unquote(s.Values[idx].(*ast.BasicLit).Value)
38 | v, err := semver.NewVersion(oldVersion)
39 | if err != nil || v.String() == "" {
40 | return false
41 | }
42 | var vInc func() semver.Version
43 | switch part {
44 | case "major":
45 | vInc = v.IncMajor
46 | case "minor":
47 | vInc = v.IncMinor
48 | case "", "patch":
49 | vInc = v.IncPatch
50 | default:
51 | return false
52 | }
53 | newVersion = "v" + vInc().String()
54 | s.Values[idx].(*ast.BasicLit).Value = fmt.Sprintf("`%s`", newVersion)
55 | return false
56 | }
57 | }
58 | }
59 | }
60 | }
61 | return true
62 | })
63 |
64 | if newVersion == "" {
65 | return oldVersion, newVersion, fmt.Errorf("failed to update the version")
66 | }
67 |
68 | f, err := os.OpenFile(fileName, os.O_RDWR, 0666)
69 | if err != nil {
70 | return oldVersion, newVersion, fmt.Errorf("could not open file: %w", err)
71 | }
72 | defer f.Close()
73 |
74 | if err := format.Node(f, fset, node); err != nil {
75 | return oldVersion, newVersion, fmt.Errorf("could not write to file: %w", err)
76 | }
77 |
78 | return oldVersion, newVersion, nil
79 | }
80 |
81 | func main() {
82 | var (
83 | fileName string
84 | varName string
85 | part string
86 | )
87 |
88 | flag.StringVar(&fileName, "file", "", "Go source file to parse")
89 | flag.StringVar(&varName, "var", "", "Variable to update")
90 | flag.StringVar(&part, "part", "patch", "Version part to increment (major, minor, patch)")
91 |
92 | flag.Parse()
93 |
94 | if fileName == "" || varName == "" {
95 | fmt.Println("Error: Both -file and -var are required")
96 | os.Exit(1)
97 | }
98 | oldVersion, newVersion, err := bumpVersion(fileName, varName, part)
99 | if err != nil {
100 | fmt.Printf("Error bumping version: %v\n", err)
101 | os.Exit(1)
102 | }
103 | fmt.Printf("Bump from %s to %s\n", oldVersion, newVersion)
104 | }
105 |
--------------------------------------------------------------------------------
/slice/README.md:
--------------------------------------------------------------------------------
1 | # sliceutil
2 | The package contains various helpers to interact with slices
--------------------------------------------------------------------------------
/slice/sync_slice.go:
--------------------------------------------------------------------------------
1 | package sliceutil
2 |
3 | import "sync"
4 |
// SyncSlice is a slice guarded by a read/write mutex so it can be shared
// between goroutines.
//
// The zero value is an empty SyncSlice ready for use. A SyncSlice must not be
// copied after first use. Direct access to the exported Slice field bypasses
// the lock; use the methods when other goroutines may be involved.
type SyncSlice[K any] struct {
	Slice []K
	mu    sync.RWMutex
}

// NewSyncSlice returns a new, empty SyncSlice.
func NewSyncSlice[K any]() *SyncSlice[K] {
	return &SyncSlice[K]{}
}

// Append adds items to the end of the slice under the write lock.
func (ss *SyncSlice[K]) Append(items ...K) {
	ss.mu.Lock()
	defer ss.mu.Unlock()

	ss.Slice = append(ss.Slice, items...)
}

// Each calls f for every element in order, stopping at the first non-nil
// error returned by f.
//
// The read lock is held for the whole iteration. f must therefore not call
// Append, Put or Empty on the same SyncSlice, because those block waiting for
// the write lock and deadlock. It should also avoid re-acquiring the read
// lock (Len, Get), which can deadlock if a writer is waiting.
func (ss *SyncSlice[K]) Each(f func(i int, k K) error) {
	ss.mu.RLock()
	defer ss.mu.RUnlock()

	for i, k := range ss.Slice {
		if err := f(i, k); err != nil {
			break
		}
	}
}

// Empty replaces the contents with a new empty slice under the write lock.
// A fresh slice is allocated so any previously obtained view of Slice is not
// overwritten by later appends.
func (ss *SyncSlice[K]) Empty() {
	ss.mu.Lock()
	defer ss.mu.Unlock()

	ss.Slice = make([]K, 0)
}

// Len returns the current number of elements.
func (ss *SyncSlice[K]) Len() int {
	ss.mu.RLock()
	defer ss.mu.RUnlock()

	return len(ss.Slice)
}

// Get returns the element at index and true, or the zero value and false when
// index is out of range.
func (ss *SyncSlice[K]) Get(index int) (K, bool) {
	ss.mu.RLock()
	defer ss.mu.RUnlock()

	if index < 0 || index >= len(ss.Slice) {
		var zero K
		return zero, false
	}
	return ss.Slice[index], true
}

// Put replaces the element at index with value. It reports false, leaving the
// slice unchanged, when index is out of range.
func (ss *SyncSlice[K]) Put(index int, value K) bool {
	ss.mu.Lock()
	defer ss.mu.Unlock()

	if index < 0 || index >= len(ss.Slice) {
		return false
	}
	ss.Slice[index] = value
	return true
}
77 | }
78 |
--------------------------------------------------------------------------------
/slice/sync_slice_test.go:
--------------------------------------------------------------------------------
1 | package sliceutil
2 |
3 | import (
4 | "sync"
5 | "testing"
6 | "time"
7 | )
8 |
9 | func TestSimpleUsage(t *testing.T) {
10 | ss := NewSyncSlice[int]()
11 | expected := 10
12 | for i := 0; i < expected; i++ {
13 | ss.Append(i)
14 | }
15 | value, ok := ss.Get(5)
16 | if !ok {
17 | t.Errorf("Failed to get value at index 5")
18 | } else if value != 5 {
19 | t.Errorf("Expected value 5 at index 5, got %d", value)
20 | }
21 |
22 | success := ss.Put(5, 20)
23 | if !success {
24 | t.Errorf("Failed to put value at index 5")
25 | }
26 |
27 | value, ok = ss.Get(5)
28 | if !ok {
29 | t.Errorf("Failed to get value at index 5 after put")
30 | } else if value != 20 {
31 | t.Errorf("Expected value 20 at index 5 after put, got %d", value)
32 | }
33 | if ss.Len() != expected {
34 | t.Errorf("Expected slice length %d, got %d", expected, ss.Len())
35 | }
36 | ss.Empty()
37 | if ss.Len() != 0 {
38 | t.Errorf("Expected slice length 0 after emptying, got %d", ss.Len())
39 | }
40 | }
41 |
42 | func TestConcurrentAppend(t *testing.T) {
43 | ss := NewSyncSlice[int]()
44 | var wg sync.WaitGroup
45 | count := 1000
46 |
47 | for i := 0; i < count; i++ {
48 | wg.Add(1)
49 | go func(val int) {
50 | defer wg.Done()
51 | ss.Append(val)
52 |
53 | if val%10 == 0 {
54 | ss.Put(val, val*2) // Double the value at positions that are multiples of 10
55 | }
56 | if val%5 == 0 {
57 | retrievedVal, _ := ss.Get(val) // Attempt to get the value at positions that are multiples of 5
58 | _ = retrievedVal // Use the retrieved value to ensure it's not optimized away
59 | }
60 | }(i)
61 | }
62 | wg.Wait()
63 |
64 | if ss.Len() != count {
65 | t.Errorf("Expected slice length %d after concurrent append, got %d", count, ss.Len())
66 | }
67 | }
68 |
69 | func TestConcurrentReadWriteAndIteration(t *testing.T) {
70 | ss := NewSyncSlice[int]()
71 | var wg sync.WaitGroup
72 | readWriteCount := 1000
73 |
74 | wg.Add(3) // Adding three groups: writer, reader, iterator
75 |
76 | // Writer goroutine
77 | go func() {
78 | defer wg.Done()
79 | for i := 0; i < readWriteCount; i++ {
80 | ss.Append(i) // Write
81 | }
82 | }()
83 |
84 | // Reader goroutine
85 | go func() {
86 | defer wg.Done()
87 |
88 | time.Sleep(250 * time.Millisecond)
89 |
90 | for i := 0; i < readWriteCount; i++ {
91 | if value, ok := ss.Get(i % ss.Len()); !ok {
92 | t.Errorf("Failed to get value at index %d", i%ss.Len())
93 | } else {
94 | _ = value // Use the value to ensure it's not optimized away
95 | }
96 | }
97 | }()
98 |
99 | // Iterator goroutine
100 | go func() {
101 | defer wg.Done()
102 | for repeat := 0; repeat < 1000; repeat++ { // Repeat the iteration 1000 times
103 | ss.Each(func(index int, value int) error {
104 | // Simulate some processing
105 | _ = index
106 | _ = value
107 | return nil
108 | })
109 | }
110 | }()
111 |
112 | wg.Wait()
113 |
114 | if ss.Len() != readWriteCount {
115 | t.Errorf("Expected slice length %d after concurrent read/write, got %d", readWriteCount, ss.Len())
116 | }
117 | }
118 |
--------------------------------------------------------------------------------
/strings/README.md:
--------------------------------------------------------------------------------
1 | # stringsutil
2 | The package contains various helpers to interact with strings
--------------------------------------------------------------------------------
/strings/strings_encoding.go:
--------------------------------------------------------------------------------
1 | package stringsutil
2 |
3 | import (
4 | "errors"
5 |
6 | "github.com/saintfish/chardet"
7 | )
8 |
// EncodingType identifies a character encoding recognised by DetectEncodingType.
type EncodingType uint8

// Known encodings. Values are assigned by declaration order (iota), so new
// entries must only be appended at the end to keep existing values stable.
const (
	Unknown EncodingType = iota
	UTF8
	UTF16BE
	UTF16LE
	UTF32BE
	UTF32LE
	ISO85591 // ISO-8859-1; the name is misspelled, prefer ISO88591
	ISO88592
	ISO88595
	ISO88596
	ISO88597
	ISO88598
	Windows1251
	Windows1256
	KOI8R
	ShiftJIS
	GB18030
	EUCJP
	EUCKR
	Big5
	ISO2022JP
	ISO2022KR
	ISO2022CN
	IBM424rtl
	IBM424ltr
	IBM420rtl
	IBM420ltr
)

// ISO88591 is the correctly spelled name for ISO-8859-1. It has the same value
// as ISO85591, which is kept for backward compatibility.
const ISO88591 = ISO85591
40 |
41 | var detector *chardet.Detector = chardet.NewTextDetector()
42 |
43 | func DetectEncodingType(data interface{}) (EncodingType, error) {
44 | var (
45 | enc *chardet.Result
46 | err error
47 | )
48 | switch dd := data.(type) {
49 | case string:
50 | enc, err = detector.DetectBest([]byte(dd))
51 | case []byte:
52 | enc, err = detector.DetectBest(dd)
53 | default:
54 | return Unknown, errors.New("unsupported type")
55 | }
56 |
57 | if err != nil || enc == nil {
58 | return Unknown, err
59 | }
60 |
61 | switch enc.Charset {
62 | case "UTF-8":
63 | return UTF8, nil
64 | case "UTF-16BE":
65 | return UTF16BE, nil
66 | case "UTF-16LE":
67 | return UTF16LE, nil
68 | case "UTF-32BE":
69 | return UTF32BE, nil
70 | case "UTF-32LE":
71 | return UTF32LE, nil
72 | case "ISO-8859-1":
73 | return ISO85591, nil
74 | case "ISO-8859-2":
75 | return ISO88592, nil
76 | case "ISO-8859-5":
77 | return ISO88595, nil
78 | case "ISO-8859-6":
79 | return ISO88596, nil
80 | case "ISO-8859-7":
81 | return ISO88597, nil
82 | case "ISO-8859-8":
83 | return ISO88598, nil
84 | case "windows-1251":
85 | return Windows1251, nil
86 | case "windows-1256":
87 | return Windows1256, nil
88 | case "KOI8-R":
89 | return KOI8R, nil
90 | case "Shift_JIS":
91 | return ShiftJIS, nil
92 | case "GB18030":
93 | return GB18030, nil
94 | case "EUC-JP":
95 | return EUCJP, nil
96 | case "EUC-KR":
97 | return EUCKR, nil
98 | case "Big5":
99 | return Big5, nil
100 | case "ISO-2022-JP":
101 | return ISO2022JP, nil
102 | case "ISO-2022-KR":
103 | return ISO2022KR, nil
104 | case "ISO-2022-CN":
105 | return ISO2022CN, nil
106 | case "IBM424_rtl":
107 | return IBM424rtl, nil
108 | case "IBM424_ltr":
109 | return IBM424ltr, nil
110 | case "IBM420_rtl":
111 | return IBM420rtl, nil
112 | case "IBM420_ltr":
113 | return IBM420ltr, nil
114 | default:
115 | return Unknown, nil
116 | }
117 | }
118 |
--------------------------------------------------------------------------------
/strings/strings_normalize.go:
--------------------------------------------------------------------------------
1 | package stringsutil
2 |
3 | import (
4 | "strings"
5 | "unicode"
6 |
7 | "github.com/microcosm-cc/bluemonday"
8 | )
9 |
10 | type NormalizeOptions struct {
11 | TrimSpaces bool
12 | TrimCutset string
13 | StripHTML bool
14 | Lowercase bool
15 | Uppercase bool
16 | StripComments bool
17 | }
18 |
19 | var DefaultNormalizeOptions NormalizeOptions = NormalizeOptions{
20 | TrimSpaces: true,
21 | StripHTML: true,
22 | }
23 |
24 | var HTMLPolicy *bluemonday.Policy = bluemonday.StrictPolicy()
25 |
26 | func NormalizeWithOptions(data string, options NormalizeOptions) string {
27 | if options.TrimSpaces {
28 | data = strings.TrimSpace(data)
29 | }
30 |
31 | if options.TrimCutset != "" {
32 | data = strings.Trim(data, options.TrimCutset)
33 | }
34 |
35 | if options.Lowercase {
36 | data = strings.ToLower(data)
37 | }
38 |
39 | if options.Uppercase {
40 | data = strings.ToUpper(data)
41 | }
42 |
43 | if options.StripHTML {
44 | data = HTMLPolicy.Sanitize(data)
45 | }
46 |
47 | if options.StripComments {
48 | if cut := strings.IndexAny(data, "#"); cut >= 0 {
49 | data = strings.TrimRightFunc(data[:cut], unicode.IsSpace)
50 | }
51 | }
52 |
53 | return data
54 | }
55 |
56 | func Normalize(data string) string {
57 | return NormalizeWithOptions(data, DefaultNormalizeOptions)
58 | }
59 |
--------------------------------------------------------------------------------
/structs/structs_test.go:
--------------------------------------------------------------------------------
1 | package structs
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 | )
7 |
type (
	// TestStruct is a flat fixture with string and int fields.
	TestStruct struct {
		Name    string
		Age     int
		Address string
	}

	// NestedStruct embeds TestStruct both by value and through a pointer.
	NestedStruct struct {
		Basic    TestStruct
		PtrField *TestStruct
	}
)
18 |
19 | func TestFilterStruct(t *testing.T) {
20 | s := TestStruct{
21 | Name: "John",
22 | Age: 30,
23 | Address: "New York",
24 | }
25 |
26 | tests := []struct {
27 | name string
28 | input interface{}
29 | includeFields []string
30 | excludeFields []string
31 | want TestStruct
32 | wantErr bool
33 | }{
34 | {
35 | name: "include specific fields",
36 | input: s,
37 | includeFields: []string{"name", "Age"},
38 | excludeFields: []string{},
39 | want: TestStruct{
40 | Name: "John",
41 | Age: 30,
42 | },
43 | wantErr: false,
44 | },
45 | {
46 | name: "exclude specific fields",
47 | input: s,
48 | includeFields: []string{},
49 | excludeFields: []string{"address"},
50 | want: TestStruct{
51 | Name: "John",
52 | Age: 30,
53 | },
54 | wantErr: false,
55 | },
56 | {
57 | name: "non-struct input",
58 | input: "not a struct",
59 | includeFields: []string{},
60 | excludeFields: []string{},
61 | want: TestStruct{},
62 | wantErr: true,
63 | },
64 | }
65 |
66 | for _, tt := range tests {
67 | t.Run(tt.name, func(t *testing.T) {
68 | got, err := FilterStruct(tt.input, tt.includeFields, tt.excludeFields)
69 | if (err != nil) != tt.wantErr {
70 | t.Errorf("FilterStruct() error = %v, wantErr %v", err, tt.wantErr)
71 | return
72 | }
73 | if !tt.wantErr {
74 | if !reflect.DeepEqual(got, tt.want) {
75 | t.Errorf("FilterStruct() = %v, want %v", got, tt.want)
76 | }
77 | }
78 | })
79 | }
80 | }
81 |
82 | func TestGetStructFields(t *testing.T) {
83 | s := TestStruct{
84 | Name: "John",
85 | Age: 30,
86 | Address: "New York",
87 | }
88 |
89 | tests := []struct {
90 | name string
91 | input interface{}
92 | want []string
93 | wantErr bool
94 | }{
95 | {
96 | name: "valid struct",
97 | input: s,
98 | want: []string{"name", "age", "address"},
99 | wantErr: false,
100 | },
101 | {
102 | name: "non-struct input",
103 | input: "not a struct",
104 | want: nil,
105 | wantErr: true,
106 | },
107 | }
108 |
109 | for _, tt := range tests {
110 | t.Run(tt.name, func(t *testing.T) {
111 | got, err := GetStructFields(tt.input)
112 | if (err != nil) != tt.wantErr {
113 | t.Errorf("GetStructFields() error = %v, wantErr %v", err, tt.wantErr)
114 | return
115 | }
116 | if !tt.wantErr {
117 | if !reflect.DeepEqual(got, tt.want) {
118 | t.Errorf("GetStructFields() = %v, want %v", got, tt.want)
119 | }
120 | }
121 | })
122 | }
123 | }
124 |
--------------------------------------------------------------------------------
/sync/adaptivewaitgroup.go:
--------------------------------------------------------------------------------
1 | package sync
2 |
3 | // Extended version of https://github.com/remeh/sizedwaitgroup
4 |
5 | import (
6 | "context"
7 | "errors"
8 | "sync"
9 | "sync/atomic"
10 |
11 | "github.com/projectdiscovery/utils/sync/semaphore"
12 | )
13 |
14 | type AdaptiveGroupOption func(*AdaptiveWaitGroup) error
15 |
16 | type AdaptiveWaitGroup struct {
17 | Size int
18 | current *atomic.Int64
19 |
20 | sem *semaphore.Semaphore
21 | wg sync.WaitGroup
22 | mu sync.Mutex // Mutex to protect access to the Size and semaphore
23 | }
24 |
25 | // WithSize sets the initial size of the waitgroup ()
26 | func WithSize(size int) AdaptiveGroupOption {
27 | return func(wg *AdaptiveWaitGroup) error {
28 | if err := validateSize(size); err != nil {
29 | return err
30 | }
31 | sem, err := semaphore.New(int64(size))
32 | if err != nil {
33 | return err
34 | }
35 | wg.sem = sem
36 | wg.Size = size
37 | return nil
38 | }
39 | }
40 |
41 | func validateSize(size int) error {
42 | if size < 1 {
43 | return errors.New("size must be at least 1")
44 | }
45 | return nil
46 | }
47 |
48 | func New(options ...AdaptiveGroupOption) (*AdaptiveWaitGroup, error) {
49 | wg := &AdaptiveWaitGroup{}
50 | for _, option := range options {
51 | if err := option(wg); err != nil {
52 | return nil, err
53 | }
54 | }
55 |
56 | wg.wg = sync.WaitGroup{}
57 | wg.current = &atomic.Int64{}
58 | return wg, nil
59 | }
60 |
61 | func (s *AdaptiveWaitGroup) Add() {
62 | _ = s.AddWithContext(context.Background())
63 | }
64 |
65 | func (s *AdaptiveWaitGroup) AddWithContext(ctx context.Context) error {
66 | select {
67 | case <-ctx.Done():
68 | return ctx.Err()
69 | default:
70 | // Attempt to acquire a semaphore slot, handle error if acquisition fails
71 | if err := s.sem.Acquire(ctx, 1); err != nil {
72 | return err
73 | }
74 | }
75 |
76 | // Safely add to the waitgroup only after acquiring the semaphore
77 | s.wg.Add(1)
78 | s.current.Add(1)
79 | return nil
80 | }
81 |
82 | func (s *AdaptiveWaitGroup) Done() {
83 | s.sem.Release(1)
84 | s.wg.Done()
85 | s.current.Add(-1)
86 | }
87 |
88 | func (s *AdaptiveWaitGroup) Wait() {
89 | s.wg.Wait()
90 | }
91 |
92 | func (s *AdaptiveWaitGroup) Resize(ctx context.Context, size int) error {
93 | s.mu.Lock()
94 | defer s.mu.Unlock()
95 |
96 | if err := validateSize(size); err != nil {
97 | return err
98 | }
99 |
100 | // Resize the semaphore with the provided context and handle any errors
101 | if err := s.sem.Resize(ctx, int64(size)); err != nil {
102 | return err
103 | }
104 | s.Size = size
105 | return nil
106 | }
107 |
108 | func (s *AdaptiveWaitGroup) Current() int {
109 | return int(s.current.Load())
110 | }
111 |
--------------------------------------------------------------------------------
/sync/semaphore/semaphore.go:
--------------------------------------------------------------------------------
1 | package semaphore
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "math"
7 | "sync/atomic"
8 |
9 | "golang.org/x/sync/semaphore"
10 | )
11 |
12 | type Semaphore struct {
13 | sem *semaphore.Weighted
14 | initialSize atomic.Int64
15 | maxSize atomic.Int64
16 | currentSize atomic.Int64
17 | }
18 |
19 | func New(size int64) (*Semaphore, error) {
20 | maxSize := int64(math.MaxInt64)
21 | s := &Semaphore{
22 | sem: semaphore.NewWeighted(maxSize),
23 | }
24 | s.initialSize.Store(size)
25 | s.maxSize.Store(maxSize)
26 | s.currentSize.Store(size)
27 | err := s.sem.Acquire(context.Background(), s.maxSize.Load()-s.initialSize.Load())
28 | return s, err
29 | }
30 |
31 | func (s *Semaphore) Acquire(ctx context.Context, n int64) error {
32 | return s.sem.Acquire(ctx, n)
33 | }
34 |
35 | func (s *Semaphore) Release(n int64) {
36 | s.sem.Release(n)
37 | }
38 |
39 | // Vary capacity by x - it's internally enqueued as a normal Acquire/Release operation as other Get/Put
40 | // but tokens are held internally
41 | func (s *Semaphore) Vary(ctx context.Context, x int64) error {
42 | switch {
43 | case x > 0:
44 | s.sem.Release(x)
45 | s.currentSize.Add(x)
46 | return nil
47 | case x < 0:
48 | err := s.sem.Acquire(ctx, x)
49 | if err != nil {
50 | return err
51 | }
52 | s.currentSize.Add(x)
53 | return nil
54 | default:
55 | return errors.New("x is zero")
56 | }
57 | }
58 |
59 | func (s *Semaphore) Resize(ctx context.Context, newSize int64) error {
60 | currentSize := s.currentSize.Load()
61 | difference := newSize - currentSize
62 |
63 | if difference == 0 {
64 | return nil // No resizing needed if the new size is the same as the current size
65 | }
66 |
67 | if difference > 0 {
68 | // Increase capacity
69 | s.sem.Release(difference)
70 | } else {
71 | // Decrease capacity
72 | err := s.sem.Acquire(ctx, -difference) // Acquire takes a positive number, so negate difference
73 | if err != nil {
74 | return err
75 | }
76 | }
77 |
78 | s.currentSize.Store(newSize)
79 | return nil
80 | }
81 |
82 | // Current size of the semaphore
83 | func (s *Semaphore) Size() int64 {
84 | return s.currentSize.Load()
85 | }
86 |
87 | // Nominal size of the sempahore
88 | func (s *Semaphore) InitialSize() int64 {
89 | return s.initialSize.Load()
90 | }
91 |
--------------------------------------------------------------------------------
/sync/sizedpool/sizedpool.go:
--------------------------------------------------------------------------------
1 | package sizedpool
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "sync"
7 |
8 | "github.com/projectdiscovery/utils/sync/semaphore"
9 | )
10 |
11 | type PoolOption[T any] func(*SizedPool[T]) error
12 |
13 | func WithSize[T any](size int64) PoolOption[T] {
14 | return func(sz *SizedPool[T]) error {
15 | if size <= 0 {
16 | return errors.New("size must be positive")
17 | }
18 | var err error
19 | sz.sem, err = semaphore.New(size)
20 | if err != nil {
21 | return err
22 | }
23 | return nil
24 | }
25 | }
26 |
27 | func WithPool[T any](p *sync.Pool) PoolOption[T] {
28 | return func(sz *SizedPool[T]) error {
29 | sz.pool = p
30 | return nil
31 | }
32 | }
33 |
34 | type SizedPool[T any] struct {
35 | sem *semaphore.Semaphore
36 | pool *sync.Pool
37 | }
38 |
39 | func New[T any](options ...PoolOption[T]) (*SizedPool[T], error) {
40 | sz := &SizedPool[T]{}
41 | for _, option := range options {
42 | if err := option(sz); err != nil {
43 | return nil, err
44 | }
45 | }
46 | return sz, nil
47 | }
48 |
49 | func (sz *SizedPool[T]) Get(ctx context.Context) (T, error) {
50 | if sz.sem != nil {
51 | if err := sz.sem.Acquire(ctx, 1); err != nil {
52 | var t T
53 | return t, err
54 | }
55 | }
56 | return sz.pool.Get().(T), nil
57 | }
58 |
59 | func (sz *SizedPool[T]) Put(x T) {
60 | if sz.sem != nil {
61 | sz.sem.Release(1)
62 | }
63 | sz.pool.Put(x)
64 | }
65 |
66 | // Vary capacity by x - it's internally enqueued as a normal Acquire/Release operation as other Get/Put
67 | // but tokens are held internally
68 | func (sz *SizedPool[T]) Vary(ctx context.Context, x int64) error {
69 | return sz.sem.Vary(ctx, x)
70 | }
71 |
72 | // Current size of the pool
73 | func (sz *SizedPool[T]) Size() int64 {
74 | return sz.sem.Size()
75 | }
76 |
--------------------------------------------------------------------------------
/sync/sizedpool/sizedpool_test.go:
--------------------------------------------------------------------------------
1 | package sizedpool
2 |
3 | import (
4 | "context"
5 | "sync"
6 | "testing"
7 | "time"
8 |
9 | "github.com/stretchr/testify/require"
10 | )
11 |
12 | type testStruct struct{}
13 |
14 | func TestSizedPool(t *testing.T) {
15 | p := &sync.Pool{
16 | New: func() any {
17 | return &testStruct{}
18 | },
19 | }
20 |
21 | // Create a new SizedPool with a max capacity of 2
22 | pool, err := New[*testStruct](
23 | WithSize[*testStruct](2),
24 | WithPool[*testStruct](p),
25 | )
26 | if err != nil {
27 | t.Errorf("Error creating pool: %v", err)
28 | }
29 |
30 | // Test Get and Put operations
31 | ctx := context.Background()
32 | obj1, err := pool.Get(ctx)
33 | require.Nil(t, err)
34 | require.NotNil(t, obj1)
35 |
36 | obj2, err := pool.Get(ctx)
37 | require.Nil(t, err)
38 | require.NotNil(t, obj2)
39 |
40 | go func() {
41 | time.Sleep(3 * time.Second)
42 | pool.Put(obj1)
43 | time.Sleep(1 * time.Second)
44 | pool.Put(obj2)
45 | }()
46 |
47 | start := time.Now()
48 | obj3, _ := pool.Get(ctx)
49 | require.WithinDuration(t, start.Add(3*time.Second), time.Now(), 500*time.Millisecond)
50 | require.NotNil(t, obj3)
51 | }
52 |
53 | func TestSizedPoolVary(t *testing.T) {
54 | p := &sync.Pool{
55 | New: func() any {
56 | return &testStruct{}
57 | },
58 | }
59 |
60 | // Create a new SizedPool with a max capacity of 2
61 | pool, err := New[*testStruct](
62 | WithSize[*testStruct](2),
63 | WithPool[*testStruct](p),
64 | )
65 | if err != nil {
66 | t.Errorf("Error creating pool: %v", err)
67 | }
68 |
69 | // Test Get and Put operations
70 | ctx := context.Background()
71 | obj1, err := pool.Get(ctx)
72 | require.Nil(t, err)
73 | require.NotNil(t, obj1)
74 |
75 | obj2, err := pool.Get(ctx)
76 | require.Nil(t, err)
77 | require.NotNil(t, obj2)
78 |
79 | var wg sync.WaitGroup
80 |
81 | wg.Add(1)
82 | go func() {
83 | defer wg.Done()
84 |
85 | obj3, err := pool.Get(context.Background())
86 | require.Nil(t, err)
87 | require.NotNil(t, obj3)
88 | }()
89 |
90 | err = pool.Vary(context.Background(), 1)
91 | require.Nil(t, err)
92 |
93 | wg.Wait()
94 | }
95 |
--------------------------------------------------------------------------------
/syscallutil/syscall_unix.go:
--------------------------------------------------------------------------------
1 | //go:build (darwin || linux) && !(386 || arm)
2 |
3 | package syscallutil
4 |
5 | import "github.com/ebitengine/purego"
6 |
7 | func loadLibrary(name string) (uintptr, error) {
8 | return purego.Dlopen(name, purego.RTLD_NOW|purego.RTLD_GLOBAL)
9 | }
10 |
--------------------------------------------------------------------------------
/syscallutil/syscall_unix_others.go:
--------------------------------------------------------------------------------
1 | //go:build (darwin || linux) && (386 || arm)
2 |
3 | package syscallutil
4 |
5 | import "errors"
6 |
// loadLibrary is the stub for darwin/linux on 386 and arm: loading is not
// implemented there, so it always returns a zero handle and an error.
func loadLibrary(name string) (uintptr, error) {
	_ = name // unused on these platforms
	var handle uintptr
	return handle, errors.New("not implemented")
}
10 |
--------------------------------------------------------------------------------
/syscallutil/syscallutil.go:
--------------------------------------------------------------------------------
1 | package syscallutil
2 |
3 | func LoadLibrary(name string) (uintptr, error) {
4 | return loadLibrary(name)
5 | }
6 |
--------------------------------------------------------------------------------
/syscallutil/syscallutil_test.go:
--------------------------------------------------------------------------------
1 | package syscallutil
2 |
3 | import (
4 | "fmt"
5 | "runtime"
6 | "testing"
7 |
8 | osutils "github.com/projectdiscovery/utils/os"
9 | "github.com/stretchr/testify/require"
10 | )
11 |
12 | func TestLoadLibrary(t *testing.T) {
13 | t.Run("Test valid library", func(t *testing.T) {
14 | var lib string
15 | if osutils.IsWindows() {
16 | lib = "ucrtbase.dll"
17 | } else if osutils.IsOSX() {
18 | lib = "libSystem.dylib"
19 | } else if osutils.IsLinux() {
20 | lib = "libc.so.6"
21 | } else {
22 | panic(fmt.Errorf("GOOS=%s is not supported", runtime.GOOS))
23 | }
24 |
25 | _, err := LoadLibrary(lib)
26 | require.NoError(t, err, "should not return an error for valid library")
27 | })
28 |
29 | t.Run("Test invalid library", func(t *testing.T) {
30 | var lib string
31 | if osutils.IsWindows() {
32 | lib = "C:\\path\\to\\invalid\\library.dll"
33 | } else if osutils.IsOSX() {
34 | lib = "/path/to/invalid/library.dylib"
35 | } else if osutils.IsLinux() {
36 | lib = "/path/to/invalid/library.so"
37 | } else {
38 | panic(fmt.Errorf("GOOS=%s is not supported", runtime.GOOS))
39 | }
40 |
41 | _, err := LoadLibrary(lib)
42 | require.Error(t, err, "should return an error for invalid library")
43 | })
44 | }
45 |
--------------------------------------------------------------------------------
/syscallutil/syscallutil_win.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package syscallutil
4 |
5 | import "golang.org/x/sys/windows"
6 |
7 | func loadLibrary(name string) (uintptr, error) {
8 | handle, err := windows.LoadLibrary(name)
9 | return uintptr(handle), err
10 | }
11 |
--------------------------------------------------------------------------------
/sysutil/sysutil.go:
--------------------------------------------------------------------------------
1 | package sysutil
2 |
3 | import (
4 | "runtime/debug"
5 | )
6 |
// SetMaxThreads limits the number of operating-system threads the Go runtime
// may use and returns the previous limit. It is a thin wrapper around
// runtime/debug.SetMaxThreads: exceeding the limit crashes the program, and
// the initial limit is 10,000 threads.
//
// The limit counts OS threads, not goroutines. The runtime starts a new thread
// only when a goroutine is runnable but every existing thread is blocked in a
// system call or cgo call, or is locked via runtime.LockOSThread.
//
// The main use is to bring down a program that creates threads without bound
// before it exhausts the operating system.
func SetMaxThreads(threads int) int {
	previous := debug.SetMaxThreads(threads)
	return previous
}
24 |
--------------------------------------------------------------------------------
/sysutil/sysutil_test.go:
--------------------------------------------------------------------------------
1 | package sysutil
2 |
3 | import (
4 | "runtime/debug"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/require"
8 | )
9 |
10 | func TestSetMaxThreads(t *testing.T) {
11 | originalMaxThreads := debug.SetMaxThreads(10000)
12 | defer debug.SetMaxThreads(originalMaxThreads)
13 |
14 | newMaxThreads := 5000
15 | previousMaxThreads := SetMaxThreads(newMaxThreads)
16 | require.Equal(t, 10000, previousMaxThreads, "Expected previous max threads to be 10000")
17 | require.Equal(t, newMaxThreads, debug.SetMaxThreads(newMaxThreads), "Expected max threads to be set to 5000")
18 |
19 | SetMaxThreads(originalMaxThreads)
20 | }
21 |
--------------------------------------------------------------------------------
/time/README.md:
--------------------------------------------------------------------------------
1 | # timeutil
2 | The package contains various helpers to interact with time
--------------------------------------------------------------------------------
/time/timeutil.go:
--------------------------------------------------------------------------------
1 | package timeutil
2 |
3 | import (
4 | "fmt"
5 | "strconv"
6 | "strings"
7 | "time"
8 | )
9 |
10 | // RFC3339ToTime converts RFC3339 (standard extended go) to time
11 | func RFC3339ToTime(s interface{}) (time.Time, error) {
12 | return time.Parse(time.RFC3339, fmt.Sprint(s))
13 | }
14 |
15 | // MsToTime converts uint64/int64 milliseconds to go time.Time
16 | func MsToTime(i64 interface{}) time.Time {
17 | // 1ms = 1000000ns
18 | switch v := i64.(type) {
19 | case int64:
20 | return time.Unix(0, v*1000000)
21 | case uint64:
22 | return time.Unix(0, int64(v)*1000000)
23 | case string:
24 | return MsToTime(stringToInt(fmt.Sprint(i64)))
25 | }
26 | return time.Time{}
27 | }
28 |
29 | func SToTime(i64 interface{}) time.Time {
30 | switch v := i64.(type) {
31 | case int64:
32 | return time.Unix(v, 0)
33 | case uint64:
34 | return time.Unix(int64(v), 0)
35 | case string:
36 | return SToTime(stringToInt(fmt.Sprint(i64)))
37 | }
38 | return time.Now()
39 | }
40 |
41 | func stringToInt(s string) interface{} {
42 | if u, err := strconv.ParseInt(s, 0, 64); err == nil {
43 | return u
44 | }
45 | if u, err := strconv.ParseUint(s, 0, 64); err == nil {
46 | return u
47 | }
48 |
49 | return 0
50 | }
51 |
52 | func ParseUnixTimestamp(s string) (time.Time, error) {
53 | i, err := strconv.ParseInt(s, 10, 64)
54 | if err != nil {
55 | return time.Time{}, err
56 | }
57 | return time.Unix(i, 0), nil
58 | }
59 |
60 | // ParseDuration is similar to time.ParseDuration but also supports days unit
61 | // if the unit is omitted, it defaults to seconds
62 | func ParseDuration(s string) (time.Duration, error) {
63 | s = strings.ToLower(s)
64 | // default to sec
65 | if _, err := strconv.Atoi(s); err == nil {
66 | s = s + "s"
67 | }
68 | // parse days unit as hours
69 | if strings.HasSuffix(s, "d") {
70 | s = strings.TrimSuffix(s, "d")
71 | if days, err := strconv.Atoi(s); err == nil {
72 | s = strconv.Itoa(days*24) + "h"
73 | }
74 | }
75 | return time.ParseDuration(s)
76 | }
77 |
--------------------------------------------------------------------------------
/time/timeutil_test.go:
--------------------------------------------------------------------------------
1 | package timeutil
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/stretchr/testify/require"
8 | )
9 |
10 | func TestRFC3339ToTime(t *testing.T) {
11 | orig := time.Now()
12 | // converts back
13 | tt, err := RFC3339ToTime(orig.Format(time.RFC3339))
14 | require.Nil(t, err, "couldn't parse string time")
15 | require.Equal(t, orig.Unix(), tt.Unix(), "times don't match")
16 | }
17 |
// TestMsToTime is currently a placeholder and asserts nothing.
// NOTE(review): the "TBD in chaos + bbsh" remark suggests coverage is expected
// to come from downstream projects — confirm, or add cases here.
func TestMsToTime(t *testing.T) {
	// TBD in chaos + bbsh
}
21 |
22 | func TestSToTime(t *testing.T) {
23 | // TBD in chaos + bbsh
24 | }
25 |
26 | func TestParseDuration(t *testing.T) {
27 | tt, err := ParseDuration("2d")
28 | require.Nil(t, err, "couldn't parse duration")
29 | require.Equal(t, time.Hour*24*2, tt, "times don't match")
30 |
31 | tt, err = ParseDuration("2")
32 | require.Nil(t, err, "couldn't parse duration")
33 | require.Equal(t, time.Second*2, tt, "times don't match")
34 | }
35 |
--------------------------------------------------------------------------------
/trace/trace_test.go:
--------------------------------------------------------------------------------
1 | package trace
2 |
3 | import (
4 | "testing"
5 | "time"
6 | )
7 |
8 | func TestFunctionWithBeforeFunction(t *testing.T) {
9 | var beforeCalled bool
10 | _, _ = Trace(func() {
11 | if !beforeCalled {
12 | t.Errorf("Before function was not called before the main function")
13 | }
14 | }, WithBefore(func() {
15 | beforeCalled = true
16 | }))
17 |
18 | if !beforeCalled {
19 | t.Errorf("Before function was not called")
20 | }
21 | }
22 |
23 | func TestFunctionWithAfterFunction(t *testing.T) {
24 | var afterCalled bool
25 | _, _ = Trace(func() {
26 | if afterCalled {
27 | t.Errorf("After function was called before the main function finished")
28 | }
29 | }, WithAfter(func() {
30 | afterCalled = true
31 | }))
32 |
33 | if !afterCalled {
34 | t.Errorf("After function was not called")
35 | }
36 | }
37 |
38 | func TestFunctionTracing(t *testing.T) {
39 | metrics, _ := Trace(func() {
40 | time.Sleep(2 * time.Second)
41 | })
42 |
43 | if metrics.ExecutionDuration.Seconds() < 2 {
44 | t.Errorf("ExecutionDuration is less than expected: %v", metrics.ExecutionDuration)
45 | }
46 |
47 | if len(metrics.Snapshots) == 0 {
48 | t.Errorf("Memory snapshots are not captured")
49 | }
50 |
51 | if metrics.MinAllocMemory == 0 {
52 | t.Errorf("MinMemory not computed")
53 | }
54 |
55 | if metrics.MaxAllocMemory == 0 {
56 | t.Errorf("MaxMemory not computed")
57 | }
58 |
59 | if metrics.AvgAllocMemory == 0 {
60 | t.Errorf("AvgMemory not computed")
61 | }
62 | }
63 |
64 | func TestFunctionWithCustomStrategy(t *testing.T) {
65 | var customLogs []string
66 | metrics, _ := Trace(func() {
67 | time.Sleep(1 * time.Second)
68 | }, WithStrategy(&CustomStrategy{metrics: &Metrics{}, logs: &customLogs}))
69 |
70 | if len(customLogs) != 2 {
71 | t.Errorf("Custom logs not captured as expected")
72 | }
73 |
74 | if customLogs[0] != "Custom Before method started." {
75 | t.Errorf("Expected custom log for Before method not found")
76 | }
77 |
78 | if customLogs[1] != "Custom After method executed." {
79 | t.Errorf("Expected custom log for After method not found")
80 | }
81 |
82 | if metrics.ExecutionDuration.Seconds() < 1 {
83 | t.Errorf("ExecutionDuration is less than expected: %v", metrics.ExecutionDuration)
84 | }
85 |
86 | if len(metrics.Snapshots) != 0 {
87 | t.Errorf("Custom strategy should not capture snapshots")
88 | }
89 | }
90 |
91 | type CustomStrategy struct {
92 | metrics *Metrics
93 | logs *[]string
94 | }
95 |
96 | func (c *CustomStrategy) Before() {
97 | *c.logs = append(*c.logs, "Custom Before method started.")
98 | c.metrics.StartTime = time.Now()
99 | }
100 |
101 | func (c *CustomStrategy) After() {
102 | *c.logs = append(*c.logs, "Custom After method executed.")
103 | c.metrics.FinishTime = time.Now()
104 | c.metrics.ExecutionDuration = c.metrics.FinishTime.Sub(c.metrics.StartTime)
105 | }
106 |
107 | func (c *CustomStrategy) GetMetrics() *Metrics {
108 | return c.metrics
109 | }
110 |
--------------------------------------------------------------------------------
/unit/doc.go:
--------------------------------------------------------------------------------
// Package unit contains common unit values, such as
// data size units (byte, kilo, mega, giga).
3 | package unit
4 |
--------------------------------------------------------------------------------
/unit/size.go:
--------------------------------------------------------------------------------
1 | package unit
2 |
// Data size units, expressed in bytes as binary (power-of-two) multiples.
const (
	// Byte is a single byte.
	Byte = 1
	// Kilo is 1024 bytes (2^10).
	Kilo = 1 << 10
	// Mega is 1024 kilo (2^20 bytes).
	Mega = 1 << 20
	// Giga is 1024 mega (2^30 bytes).
	Giga = 1 << 30
)
9 |
--------------------------------------------------------------------------------
/update/gh_test.go:
--------------------------------------------------------------------------------
1 | //go:build update
2 |
3 | // update related tests are only executed when update tag is provided (ex: go test -tags update ./...) to avoid failures due to rate limiting
4 | package updateutils
5 |
6 | import (
7 | "io"
8 | "io/fs"
9 | "testing"
10 |
11 | "github.com/stretchr/testify/require"
12 | )
13 |
14 | // TestDownloadNucleiRelease tests downloading nuclei release
15 | func TestDownloadNucleiRelease(t *testing.T) {
16 | HideProgressBar = true
17 | gh, err := NewghReleaseDownloader("nuclei")
18 | require.Nil(t, err)
19 | _, err = gh.GetExecutableFromAsset()
20 | require.Nil(t, err)
21 | }
22 |
23 | // TestDownloadNucleiTemplatesFromSource tests downloading nuclei-templates from source
24 | func TestDownloadNucleiTemplatesFromSource(t *testing.T) {
25 | gh, err := NewghReleaseDownloader("nuclei-templates")
26 | require.Nil(t, err)
27 | counter := 0
28 | callback := func(path string, fileInfo fs.FileInfo, data io.Reader) error {
29 | _ = fileInfo.Name()
30 | counter++
31 | return nil
32 | }
33 | err = gh.DownloadSourceWithCallback(false, callback)
34 | require.Nil(t, err)
35 | // actual content is lot more than 100 files
36 | require.Greater(t, counter, 100)
37 | }
38 |
39 | // TestDownloadToolWithDifferentName tests downloading a tool with different name than repo name
40 | // by default repo name is considered as executable name
41 | func TestDownloadToolWithDifferentName(t *testing.T) {
42 | gh, err := NewghReleaseDownloader("interactsh")
43 | require.Nil(t, err)
44 | gh.SetToolName("interactsh-client")
45 | _, err = gh.GetExecutableFromAsset()
46 | require.Nil(t, err)
47 | }
48 |
--------------------------------------------------------------------------------
/update/types_test.go:
--------------------------------------------------------------------------------
1 | package updateutils
2 |
3 | import (
4 | "fmt"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestIsOutdated(t *testing.T) {
11 | tests := []struct {
12 | current string
13 | latest string
14 | expected bool
15 | }{
16 | {
17 | current: "1.0.0",
18 | latest: "1.1.0",
19 | expected: true,
20 | },
21 | {
22 | current: "1.0.0",
23 | latest: "1.0.0",
24 | expected: false,
25 | },
26 | {
27 | current: "1.1.0",
28 | latest: "1.0.0",
29 | expected: false,
30 | },
31 | {
32 | current: "1.0.0-dev",
33 | latest: "1.0.0",
34 | expected: true,
35 | },
36 | {
37 | current: "invalid",
38 | latest: "1.0.0",
39 | expected: true,
40 | },
41 | {
42 | current: "invalid1",
43 | latest: "invalid2",
44 | expected: true,
45 | },
46 | {
47 | current: "1.0.0-alpha",
48 | latest: "1.0.0",
49 | expected: true,
50 | },
51 | {
52 | current: "1.0.0-alpha",
53 | latest: "1.0.0-beta",
54 | expected: true,
55 | },
56 | }
57 |
58 | for _, tt := range tests {
59 | t.Run(fmt.Sprintf("current: %v, latest: %v", tt.current, tt.latest), func(t *testing.T) {
60 | assert.Equal(t, tt.expected, IsOutdated(tt.current, tt.latest), "version comparison failed")
61 | })
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/update/utils_all.go:
--------------------------------------------------------------------------------
1 | //go:build !linux
2 | // +build !linux
3 |
4 | package updateutils
5 |
6 | import (
7 | "runtime"
8 | )
9 |
// GetOSVendor returns the operating system vendor. This variant is built only
// under the "!linux" constraint and simply returns runtime.GOOS
// (e.g. "darwin" or "windows").
func GetOSVendor() string {
	return runtime.GOOS
}
15 |
--------------------------------------------------------------------------------
/update/utils_linux.go:
--------------------------------------------------------------------------------
1 | //go:build linux
2 | // +build linux
3 |
4 | package updateutils
5 |
6 | import (
7 | "github.com/zcalusic/sysinfo"
8 | )
9 |
10 | // Get OS Vendor returns the linux distribution vendor
11 | // if not linux then returns runtime.GOOS
12 | func GetOSVendor() string {
13 | var si sysinfo.SysInfo
14 | si.GetSysInfo()
15 | return si.OS.Vendor
16 | }
17 |
--------------------------------------------------------------------------------
/update/utils_test.go:
--------------------------------------------------------------------------------
1 | package updateutils
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/logrusorgru/aurora"
7 | )
8 |
9 | func TestGetVersionDescription(t *testing.T) {
10 | Aurora = aurora.NewAurora(false)
11 | tests := []struct {
12 | current string
13 | latest string
14 | want string
15 | }{
16 | {
17 | current: "v2.9.1-dev",
18 | latest: "v2.9.1",
19 | want: "(outdated)",
20 | },
21 | {
22 | current: "v2.9.1-dev",
23 | latest: "v2.9.2",
24 | want: "(outdated)",
25 | },
26 | {
27 | current: "v2.9.1-dev",
28 | latest: "v2.9.0",
29 | want: "(development)",
30 | },
31 | {
32 | current: "v2.9.1",
33 | latest: "v2.9.1",
34 | want: "(latest)",
35 | },
36 | {
37 | current: "v2.9.1",
38 | latest: "v2.9.2",
39 | want: "(outdated)",
40 | },
41 | }
42 | for _, test := range tests {
43 | if GetVersionDescription(test.current, test.latest) != test.want {
44 | t.Errorf("GetVersionDescription(%v, %v) = %v, want %v", test.current, test.latest, GetVersionDescription(test.current, test.latest), test.want)
45 | }
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/url/orderedparams_test.go:
--------------------------------------------------------------------------------
1 | package urlutil
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "net/url"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/require"
10 | )
11 |
12 | func TestOrderedParam(t *testing.T) {
13 | p := NewOrderedParams()
14 | p.Add("sqli", "1+AND+(SELECT+*+FROM+(SELECT(SLEEP(12)))nQIP)")
15 | p.Add("xss", "")
16 | p.Add("xssiwthspace", "