├── .DS_Store
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── FAQ.md
├── LICENSE
├── README.md
├── c2pa
│   ├── README.md
│   ├── c2pa_watermark_example.py
│   ├── example.jpg
│   ├── example_wm.jpg
│   ├── example_wm_signed.jpg
│   ├── keys
│   │   ├── README.md
│   │   ├── es256_certs.pem
│   │   └── es256_private.key
│   └── manifest.json
├── images
│   ├── ghost.png
│   ├── ghost_P.png
│   ├── ghost_Q.png
│   ├── ripley.jpg
│   ├── ripley_P.png
│   ├── ripley_Q.png
│   ├── ufo_240.jpg
│   ├── ufo_240_P.png
│   └── ufo_240_Q.png
├── js
│   ├── README.md
│   ├── deps
│   │   ├── bch_ecc.min.js
│   │   ├── ort-wasm-simd-threaded.jsep.mjs
│   │   ├── ort-wasm-simd-threaded.jsep.wasm
│   │   ├── ort-wasm-simd-threaded.mjs
│   │   ├── ort-wasm-simd-threaded.wasm
│   │   ├── ort.webgpu.min.js
│   │   ├── ort.webgpu.min.js.map
│   │   ├── ort.webgpu.min.mjs
│   │   └── ort.webgpu.min.mjs.map
│   ├── index.html
│   ├── tm_datalayer.js
│   └── tm_watermark.js
├── python
│   ├── CONFIG.md
│   ├── LICENSE
│   ├── README.md
│   ├── pyproject.toml
│   ├── requirements.txt
│   ├── setup.py
│   ├── test.py
│   └── trustmark
│       ├── KBNet
│       │   ├── README.md
│       │   ├── arch_util.py
│       │   ├── kb_utils.py
│       │   ├── kbnet.py
│       │   ├── kbnet_l_arch.py
│       │   └── kbnet_s_arch.py
│       ├── __init__.py
│       ├── bchecc.py
│       ├── datalayer.py
│       ├── denoise.py
│       ├── model.py
│       ├── models
│       │   └── README.md
│       ├── trustmark.py
│       └── unet.py
└── rust
    ├── .cargo
    │   └── config.toml
    ├── .gitignore
    ├── Cargo.lock
    ├── Cargo.toml
    ├── LICENSE
    ├── README.md
    ├── benches
    │   ├── encode.rs
    │   ├── encode.sh
    │   ├── load.rs
    │   └── load.sh
    ├── crates
    │   ├── trustmark-cli
    │   │   ├── Cargo.toml
    │   │   ├── README.md
    │   │   └── src
    │   │       └── main.rs
    │   └── xtask
    │       ├── Cargo.toml
    │       └── src
    │           └── main.rs
    ├── models
    │   └── .gitkeep
    └── src
        ├── bits.rs
        ├── bits
        │   └── bch.rs
        ├── image_processing.rs
        ├── lib.rs
        └── model.rs
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Adobe Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our project and community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation.
6 |
7 | We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
8 |
9 | ## Our Standards
10 |
11 | Examples of behavior that contributes to a positive environment for our project and community include:
12 |
13 | * Demonstrating empathy and kindness toward other people
14 | * Being respectful of differing opinions, viewpoints, and experiences
15 | * Giving and gracefully accepting constructive feedback
16 | * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
17 | * Focusing on what is best, not just for us as individuals but for the overall community
18 |
19 | Examples of unacceptable behavior include:
20 |
21 | * The use of sexualized language or imagery, and sexual attention or advances of any kind
22 | * Trolling, insulting or derogatory comments, and personal or political attacks
23 | * Public or private harassment
24 | * Publishing others’ private information, such as a physical or email address, without their explicit permission
25 | * Other conduct which could reasonably be considered inappropriate in a professional setting
26 |
27 | ## Our Responsibilities
28 |
29 | Project maintainers are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any instances of unacceptable behavior.
30 |
31 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for behaviors that they deem inappropriate, threatening, offensive, or harmful.
32 |
33 | ## Scope
34 |
35 | This Code of Conduct applies when an individual is representing the project or its community both within project spaces and in public spaces. Examples of representing a project or community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
36 |
37 | ## Enforcement
38 |
39 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by first contacting the project team. Oversight of Adobe projects is handled by the Adobe Open Source Office, which has final say in any violations and enforcement of this Code of Conduct and can be reached at Grp-opensourceoffice@adobe.com. All complaints will be reviewed and investigated promptly and fairly.
40 |
41 | The project team must respect the privacy and security of the reporter of any incident.
42 |
43 | Project maintainers who do not follow or enforce the Code of Conduct may face temporary or permanent repercussions as determined by other members of the project's leadership or the Adobe Open Source Office.
44 |
45 | ## Enforcement Guidelines
46 |
47 | Project maintainers will follow these Community Impact Guidelines in determining the consequences for any action they deem to be in violation of this Code of Conduct:
48 |
49 | **1. Correction**
50 |
51 | Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.
52 |
53 | Consequence: A private, written warning from project maintainers describing the violation and why the behavior was unacceptable. A public apology may be requested from the violator before any further involvement in the project by the violator.
54 |
55 | **2. Warning**
56 |
57 | Community Impact: A relatively minor violation through a single incident or series of actions.
58 |
59 | Consequence: A written warning from project maintainers that includes stated consequences for continued unacceptable behavior. Violator must refrain from interacting with the people involved for a specified period of time as determined by the project maintainers, including, but not limited to, unsolicited interaction with those enforcing the Code of Conduct through channels such as community spaces and social media. Continued violations may lead to a temporary or permanent ban.
60 |
61 | **3. Temporary Ban**
62 |
63 | Community Impact: A more serious violation of community standards, including sustained unacceptable behavior.
64 |
65 | Consequence: A temporary ban from any interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Failure to comply with the temporary ban may lead to a permanent ban.
66 |
67 | **4. Permanent Ban**
68 |
69 | Community Impact: Demonstrating a consistent pattern of violation of community standards or an egregious violation of community standards, including, but not limited to, sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
70 |
71 | Consequence: A permanent ban from any interaction with the community.
72 |
73 | ## Attribution
74 |
75 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1,
76 | available at [https://contributor-covenant.org/version/2/1][version]
77 |
78 | [homepage]: https://contributor-covenant.org
79 | [version]: https://contributor-covenant.org/version/2/1
80 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Thanks for choosing to contribute!
4 |
5 | The following are a set of guidelines to follow when contributing to this project.
6 |
7 | ## Code Of Conduct
8 |
9 | This project adheres to the Adobe [code of conduct](../CODE_OF_CONDUCT.md). By participating,
10 | you are expected to uphold this code. Please report unacceptable behavior to
11 | [Grp-opensourceoffice@adobe.com](mailto:Grp-opensourceoffice@adobe.com).
12 |
13 | ## Have A Question?
14 |
15 | Start by filing an issue. The existing committers on this project work to reach
16 | consensus around project direction and issue solutions within issue threads
17 | (when appropriate).
18 |
19 | ## Contributor License Agreement
20 |
21 | All third-party contributions to this project must be accompanied by a signed contributor
22 | license agreement. This gives Adobe permission to redistribute your contributions
23 | as part of the project. [Sign our CLA](https://opensource.adobe.com/cla.html). You
24 | only need to submit an Adobe CLA one time, so if you have submitted one previously,
25 | you are good to go!
26 |
27 | ## Code Reviews
28 |
29 | All submissions should come in the form of pull requests and need to be reviewed
30 | by project committers. Read [GitHub's pull request documentation](https://help.github.com/articles/about-pull-requests/)
31 | for more information on sending pull requests.
32 |
33 | Lastly, please follow the [pull request template](PULL_REQUEST_TEMPLATE.md) when
34 | submitting a pull request!
35 |
36 | ## From Contributor To Committer
37 |
38 | We love contributions from our community! If you'd like to go a step beyond contributor
39 | and become a committer with full write access and a say in the project, you must
40 | be invited to the project. The existing committers employ an internal nomination
41 | process that must reach lazy consensus (silence is approval) before invitations
42 | are issued. If you feel you are qualified and want to get more deeply involved,
43 | feel free to reach out to existing committers to have a conversation about that.
44 |
45 | ## Security Issues
46 |
47 | Security issues shouldn't be reported on this issue tracker. Instead, [file an issue to our security experts](https://helpx.adobe.com/security/alertus.html).
48 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2023 Adobe
2 | All Rights Reserved.
3 |
4 | NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | accordance with the terms of the license agreement accompanying it.
6 |
7 | MIT License
8 |
9 | Permission is hereby granted, free of charge, to any person obtaining a copy
10 | of this software and associated documentation files (the "Software"), to deal
11 | in the Software without restriction, including without limitation the rights
12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | copies of the Software, and to permit persons to whom the Software is
14 | furnished to do so, subject to the following conditions:
15 |
16 | The above copyright notice and this permission notice shall be included in all
17 | copies or substantial portions of the Software.
18 |
19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | SOFTWARE.
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TrustMark
2 |
3 | This repository contains the official, open source implementation of TrustMark watermarking for the Content Authenticity Initiative (CAI) as described in [**TrustMark - Universal Watermarking for Arbitrary Resolution Images**](https://arxiv.org/abs/2311.18297) (`arXiv:2311.18297`) by [Tu Bui](https://www.surrey.ac.uk/people/tu-bui)[^1], [Shruti Agarwal](https://research.adobe.com/person/shruti-agarwal/)[^2], and [John Collomosse](https://www.collomosse.com)[^1] [^2].
4 |
5 | [^1]: [DECaDE](https://decade.ac.uk/) Centre for the Decentralized Digital Economy, University of Surrey, UK.
6 |
7 | [^2]: [Adobe Research](https://research.adobe.com/), San Jose, CA.
8 |
9 | ## Overview
10 |
11 | This repository contains the following directories:
12 |
13 | - `/python`: Python implementation of TrustMark for encoding, decoding and removing image watermarks (using PyTorch). For information on configuring TrustMark in Python, see [Configuring TrustMark](python/CONFIG.md).
14 | - `/js`: JavaScript implementation of TrustMark decoding of image watermarks (using ONNX). For more information, see [TrustMark - JavaScript implementation](js/README.md).
15 | - `/rust`: Rust implementation of TrustMark. For more information, see [TrustMark - Rust implementation](rust/README.md).
16 | - `/c2pa`: Python example of how to indicate the presence of a TrustMark watermark in a C2PA manifest. For more information, see [Using TrustMark with C2PA](c2pa/README.md).
17 |
18 | Model files (a PyTorch **ckpt** file for Python and an ONNX **onnx** file for JavaScript) are not packaged in this repository due to their size; they are downloaded on first use. See the code for [URLs and MD5 hashes](https://github.com/adobe/trustmark/blob/4ef0dde4abd84d1c6873e7c5024482f849db2c73/python/trustmark/trustmark.py#L30) if you need direct download links.
19 |
20 | More information:
21 |
22 | - For answers to common questions, see the [FAQ](FAQ.md).
23 | - For information on configuring TrustMark in Python, see [Configuring TrustMark](python/CONFIG.md).
24 |
25 | ## Installation
26 |
27 | ### Prerequisite
28 |
29 | You must have Python 3.8.5 or higher to use the TrustMark Python implementation.
30 |
31 | ### Installing from PyPI
32 |
33 | The easiest way to install TrustMark is from the [Python Package Index (PyPI)](https://pypi.org/project/trustmark/) by entering this command:
34 |
35 | ```
36 | pip install trustmark
37 | ```
38 |
39 | Alternatively, after you've cloned the repository, you can install from the `python` directory:
40 |
41 | ```
42 | cd trustmark/python
43 | pip install .
44 | ```
45 |
46 | ## Quickstart
47 |
48 | To get started quickly, run the `python/test.py` script that provides examples of watermarking several
49 | image files from the `images` directory.
50 |
51 | ### Run the example
52 |
53 | Run the example as follows:
54 |
55 | ```sh
56 | cd trustmark/python
57 | python test.py
58 | ```
59 |
60 | You'll see output like this:
61 |
62 | ```
63 | Initializing TrustMark watermarking with ECC using [cpu]
64 | Extracted secret: 1000000100001110000010010001011110010001011000100000100110110 (schema 1)
65 | PSNR = 50.357909
66 | No secret after removal
67 | ```
68 |
69 | ### Example script
70 |
71 | The `python/test.py` script provides examples of watermarking a JPEG photo, a JPEG GenAI image, and an RGBA PNG image. The example uses TrustMark variant Q to encode the word `mysecret` (ASCII7 encoding) into the image `ufo_240.jpg`; the watermark is then decoded and finally removed from the image.
72 |
73 | ```python
74 | from trustmark import TrustMark
75 | from PIL import Image
76 |
77 | # init
78 | tm=TrustMark(verbose=True, model_type='Q') # or try P
79 |
80 | # encoding example
81 | cover = Image.open('images/ufo_240.jpg').convert('RGB')
82 | tm.encode(cover, 'mysecret').save('ufo_240_Q.png')
83 |
84 | # decoding example
85 | cover = Image.open('images/ufo_240_Q.png').convert('RGB')
86 | wm_secret, wm_present, wm_schema = tm.decode(cover)
87 |
88 | if wm_present:
89 | print(f'Extracted secret: {wm_secret}')
90 | else:
91 | print('No watermark detected')
92 |
93 | # removal example
94 | stego = Image.open('images/ufo_240_Q.png').convert('RGB')
95 | im_recover = tm.remove_watermark(stego)
96 | im_recover.save('images/recovered.png')
97 | ```
98 |
99 | ## GPU setup
100 |
101 | TrustMark runs well on CPU hardware.
102 |
103 | To leverage GPU compute for the PyTorch implementation on Ubuntu Linux, first install Conda, then use the following commands to install:
104 |
105 | ```sh
106 | conda create --name trustmark python=3.10
107 | conda activate trustmark
108 | conda install pytorch cudatoolkit=12.8 -c pytorch -c conda-forge
109 | pip install torch==2.1.2 torchvision==0.16.2 -f https://download.pytorch.org/whl/torch_stable.html
110 | pip install .
111 | ```
112 |
113 | For the JavaScript implementation, a Chromium browser automatically uses WebGPU, if available.
114 |
115 | ## Data schema
116 |
117 | TrustMark encodes a payload (the watermark data embedded within the image) of 100 bits.
118 | You can configure an error correction level over the raw 100 bits of payload to maintain reliability under transformations or noise.
119 |
120 | In payload encoding, the version bits comprise two reserved (unused) bits and two bits encoding an integer value 0-3 that specifies the data schema as follows (see the sketch after this list):
121 | - 0: BCH_SUPER
122 | - 1: BCH_5
123 | - 2: BCH_4
124 | - 3: BCH_3
125 |
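As an illustration, here is a minimal Python sketch of how the version can be read back from a decoded 100-bit payload. It mirrors `DataLayer_GetVersion` in `js/tm_datalayer.js`, which reads the version from the last two bits:

```python
def payload_schema_version(bits):
    """Return the schema version (0-3) from a decoded 100-bit payload.

    A sketch mirroring DataLayer_GetVersion in js/tm_datalayer.js: the last
    two bits encode the schema integer; the two bits before them are the
    reserved (unused) bits.
    """
    assert len(bits) == 100
    return bits[-2] * 2 + bits[-1]  # 0=BCH_SUPER, 1=BCH_5, 2=BCH_4, 3=BCH_3
```
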
126 | For more details and information on configuring the encoding mode in Python, see [Configuring TrustMark](python/CONFIG.md).
127 |
128 | ## Citation
129 |
130 | If you find this work useful, please cite the repository and/or TrustMark paper as follows:
131 |
132 | ```
133 | @article{trustmark,
134 | title={Trustmark: Universal Watermarking for Arbitrary Resolution Images},
135 | author={Bui, Tu and Agarwal, Shruti and Collomosse, John},
136 | journal = {ArXiv e-prints},
137 | archivePrefix = "arXiv",
138 | eprint = {2311.18297},
139 | year = 2023,
140 | month = nov
141 | }
142 | ```
143 |
144 | ## License
145 |
146 | This package is distributed under the terms of the [MIT license](https://github.com/adobe/trustmark/blob/main/LICENSE).
147 |
--------------------------------------------------------------------------------
/c2pa/README.md:
--------------------------------------------------------------------------------
1 | # Using TrustMark with C2PA
2 |
3 | Open standards such as Content Credentials, developed by the [Coalition for Content Provenance and Authenticity (C2PA)](https://c2pa.org/), describe ways to encode information about an image’s history or _provenance_, such as how and when it was made. This information is usually carried within the image’s metadata.
4 |
5 | ## Durable Content Credentials
6 |
7 | C2PA manifest data can be accidentally removed when the image is shared through platforms that do not yet support the standard. If a copy of the manifest data is retained in a database, the TrustMark identifier carried inside the watermark can be used as a key to look up that information from the database. This is referred to as a [_Durable Content Credential_](https://contentauthenticity.org/blog/durable-content-credentials) and the technical term for the identifier is a _soft binding_.
8 |
9 | To create a soft binding, TrustMark encodes a random identifier via one of the encoding types in the [data schema](../README.md#data-schema). For example, [`c2pa/c2pa_watermark_example.py`](https://github.com/adobe/trustmark/blob/main/c2pa/c2pa_watermark_example.py) shows how to reflect the identifier within the C2PA manifest using a _soft binding assertion_.
10 |
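For instance, this condensed sketch of the full example generates a random identifier sized to the schema capacity and embeds it as a binary payload:

```python
import random
from PIL import Image
from trustmark import TrustMark

# Condensed from c2pa_watermark_example.py (below): embed a random
# identifier, sized to the encoding's capacity, as a binary payload.
tm = TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.BCH_4)
wm_id = ''.join(random.choice('01') for _ in range(tm.schemaCapacity()))
cover = Image.open('example.jpg').convert('RGB')
tm.encode(cover, wm_id, MODE='binary').save('example_wm.jpg')
```
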
11 | For more information, see the [FAQ](../FAQ.md#how-does-trustmark-align-with-provenance-standards-such-as-the-c2pa).
12 |
13 | ## Signpost watermark
14 |
15 | TrustMark [coexists well with most other image watermarks](https://arxiv.org/abs/2501.17356) and so can be used as a _signpost_ to indicate the co-presence of another watermarking technology. This can be helpful because, as an open technology, TrustMark can indicate (signpost) which decoder to obtain and run on an image to decode a soft-binding identifier for C2PA.
16 |
17 | In this mode the encoding should be `Encoding.BCH_SUPER` and the payload should contain an integer identifier that describes the co-present watermark. The integer should be taken from the registry of C2PA-approved watermarks listed in the normative C2PA [softbinding-algorithms-list](https://github.com/c2pa-org/softbinding-algorithms-list) repository.
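
A sketch of signpost encoding follows; the registry value used here is a placeholder, not a real registry entry:

```python
from PIL import Image
from trustmark import TrustMark

SIGNPOST_ID = 42  # placeholder; use the value assigned in the C2PA registry

tm = TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.BCH_SUPER)
bits = tm.schemaCapacity()  # 40 data bits under BCH_SUPER
payload = format(SIGNPOST_ID, f'0{bits}b')  # zero-padded binary string
cover = Image.open('example.jpg').convert('RGB')
tm.encode(cover, payload, MODE='binary').save('example_signpost.jpg')
```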
18 |
--------------------------------------------------------------------------------
/c2pa/c2pa_watermark_example.py:
--------------------------------------------------------------------------------
1 | import random,os
2 | from PIL import Image
3 | import json
4 | import struct
5 |
6 | from trustmark import TrustMark
7 |
8 | from PIL import Image
9 |
10 | TM_SCHEMA_CODE=TrustMark.Encoding.BCH_4
11 |
12 | def uuidgen(bitlen):
13 |
14 | id = ''.join(random.choice('01') for _ in range(bitlen))
15 | return id
16 |
17 |
18 | def embed_watermark(img_in, img_out, watermarkID, tm):
19 | cover = Image.open(img_in)
20 | rgb=cover.convert('RGB')
21 | has_alpha=cover.mode== 'RGBA'
22 | if (has_alpha):
23 | alpha=cover.split()[-1]
24 | encoded=tm.encode(rgb, watermarkID, MODE='binary')
25 | params={
26 | "exif":cover.info.get('exif'),
27 | "icc_profile":cover.info.get('icc_profile'),
28 | "dpi":cover.info.get('dpi')
29 | }
30 | not_none_params = {k:v for k, v in params.items() if v is not None}
31 | encoded.save(img_out, **not_none_params)
32 |
33 | def build_manifest(watermarkID, img_in):
34 |
35 | assertions=[]
36 | assertions.append(build_softbinding('com.adobe.trustmark.Q',str(TM_SCHEMA_CODE)+"*"+watermarkID))
37 | actions=[]
38 | act=dict()
39 | act['action']='c2pa.watermarked'
40 | actions.append(act)
41 |
42 | manifest=dict()
43 | manifest['claim_generator']="python_trustmark/c2pa"
44 | manifest['title']="Watermarked Image"
45 | manifest['thumbnail']=dict()
46 | manifest['ingredient_paths']=[img_in]
47 |
48 | ext=img_in.split('.')[-1]
49 | manifest['thumbnail']['format']="image/"+ext
50 | manifest['thumbnail']['identifier']=img_in
51 | manifest['assertions']=assertions
52 | manifest['actions']=actions
53 |
54 | return manifest
55 |
56 | def build_softbinding(alg,val):
57 | sba=dict()
58 | sba['label']='c2pa.soft-binding'
59 | sba['data']=dict()
60 | sba['data']['alg']=alg
61 | sba['data']['blocks']=list()
62 | blk=dict()
63 | blk['scope']=dict()
64 | blk['value']=val
65 | sba['data']['blocks'].append(blk)
66 | return sba
67 |
68 | def manifest_add_signing(mf):
69 | mf['alg']='es256'
70 | mf['ta_url']='http://timestamp.digicert.com'
71 | mf['private_key']='keys/es256_private.key'
72 | mf['sign_cert']='keys/es256_certs.pem'
73 | return mf
74 |
75 | def manifest_add_creator(mf,name):
76 | cwa=dict()
77 | cwa['label']='stds.schema-org.CreativeWork'
78 | cwa['data']=dict()
79 | cwa['data']['@context']='https://schema.org'
80 | cwa['data']['@type']='CreativeWork'
81 | author=dict()
82 | author['@type']='Person'
83 | author['name']=name
84 | cwa['data']['author']=[author]
85 | mf['assertions'].append(cwa)
86 | return mf
87 |
88 |
89 | def main() :
90 |
91 | img_in='example.jpg'
92 | img_out='example_wm.jpg'
93 | img_out_signed='example_wm_signed.jpg'
94 |
95 | # Generate a random watermark ID
96 | tm=TrustMark(verbose=True, model_type='Q', encoding_type=TM_SCHEMA_CODE)
97 | bitlen=tm.schemaCapacity()
98 | id=uuidgen(bitlen)
99 |
100 | # Encode watermark
101 | embed_watermark(img_in, img_out, id, tm)
102 |
103 | # Build manifest
104 | mf=build_manifest(id, img_in)
105 | mf=manifest_add_creator(mf,"Walter Mark")
106 | mf=manifest_add_signing(mf)
107 |
108 |
109 | fp=open('manifest.json','wt')
110 | fp.write(json.dumps(mf, indent=4))
111 | fp.close()
112 | os.system('c2patool '+img_out+' -m manifest.json -f -o '+img_out_signed)
113 |
114 |
115 |
116 | # Check watermark present
117 | stego = Image.open(img_out_signed).convert('RGB')
118 | wm_id, wm_present, wm_schema = tm.decode(stego, 'binary')
119 | if wm_present:
120 | print('Watermark detected in signed image')
121 | if wm_id==id:
122 | print('Watermark is correct')
123 | print(id)
124 | else:
125 | print('Watermark does not match!')
126 | print(id)
127 | print(wm_id)
128 | else:
129 | print('No watermark detected!')
130 |
131 |
132 |
133 | if __name__ == "__main__":
134 | main()
135 |
136 |
137 |
--------------------------------------------------------------------------------
/c2pa/example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/c2pa/example.jpg
--------------------------------------------------------------------------------
/c2pa/example_wm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/c2pa/example_wm.jpg
--------------------------------------------------------------------------------
/c2pa/example_wm_signed.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/c2pa/example_wm_signed.jpg
--------------------------------------------------------------------------------
/c2pa/keys/README.md:
--------------------------------------------------------------------------------
1 | # CAI Test Signing Keys
2 |
3 | These public test keys are updated from time to time; they were taken in December 2023 from https://github.com/contentauth/c2patool/tree/main/sample
4 |
5 | Please source new keys if expired.
6 |
7 | These are not secrets.
8 |
--------------------------------------------------------------------------------
/c2pa/keys/es256_certs.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIChzCCAi6gAwIBAgIUcCTmJHYF8dZfG0d1UdT6/LXtkeYwCgYIKoZIzj0EAwIw
3 | gYwxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTESMBAGA1UEBwwJU29tZXdoZXJl
4 | MScwJQYDVQQKDB5DMlBBIFRlc3QgSW50ZXJtZWRpYXRlIFJvb3QgQ0ExGTAXBgNV
5 | BAsMEEZPUiBURVNUSU5HX09OTFkxGDAWBgNVBAMMD0ludGVybWVkaWF0ZSBDQTAe
6 | Fw0yMjA2MTAxODQ2NDBaFw0zMDA4MjYxODQ2NDBaMIGAMQswCQYDVQQGEwJVUzEL
7 | MAkGA1UECAwCQ0ExEjAQBgNVBAcMCVNvbWV3aGVyZTEfMB0GA1UECgwWQzJQQSBU
8 | ZXN0IFNpZ25pbmcgQ2VydDEZMBcGA1UECwwQRk9SIFRFU1RJTkdfT05MWTEUMBIG
9 | A1UEAwwLQzJQQSBTaWduZXIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQPaL6R
10 | kAkYkKU4+IryBSYxJM3h77sFiMrbvbI8fG7w2Bbl9otNG/cch3DAw5rGAPV7NWky
11 | l3QGuV/wt0MrAPDoo3gwdjAMBgNVHRMBAf8EAjAAMBYGA1UdJQEB/wQMMAoGCCsG
12 | AQUFBwMEMA4GA1UdDwEB/wQEAwIGwDAdBgNVHQ4EFgQUFznP0y83joiNOCedQkxT
13 | tAMyNcowHwYDVR0jBBgwFoAUDnyNcma/osnlAJTvtW6A4rYOL2swCgYIKoZIzj0E
14 | AwIDRwAwRAIgOY/2szXjslg/MyJFZ2y7OH8giPYTsvS7UPRP9GI9NgICIDQPMKrE
15 | LQUJEtipZ0TqvI/4mieoyRCeIiQtyuS0LACz
16 | -----END CERTIFICATE-----
17 | -----BEGIN CERTIFICATE-----
18 | MIICajCCAg+gAwIBAgIUfXDXHH+6GtA2QEBX2IvJ2YnGMnUwCgYIKoZIzj0EAwIw
19 | dzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQHDAlTb21ld2hlcmUx
20 | GjAYBgNVBAoMEUMyUEEgVGVzdCBSb290IENBMRkwFwYDVQQLDBBGT1IgVEVTVElO
21 | R19PTkxZMRAwDgYDVQQDDAdSb290IENBMB4XDTIyMDYxMDE4NDY0MFoXDTMwMDgy
22 | NzE4NDY0MFowgYwxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTESMBAGA1UEBwwJ
23 | U29tZXdoZXJlMScwJQYDVQQKDB5DMlBBIFRlc3QgSW50ZXJtZWRpYXRlIFJvb3Qg
24 | Q0ExGTAXBgNVBAsMEEZPUiBURVNUSU5HX09OTFkxGDAWBgNVBAMMD0ludGVybWVk
25 | aWF0ZSBDQTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABHllI4O7a0EkpTYAWfPM
26 | D6Rnfk9iqhEmCQKMOR6J47Rvh2GGjUw4CS+aLT89ySukPTnzGsMQ4jK9d3V4Aq4Q
27 | LsOjYzBhMA8GA1UdEwEB/wQFMAMBAf8wDgYDVR0PAQH/BAQDAgGGMB0GA1UdDgQW
28 | BBQOfI1yZr+iyeUAlO+1boDitg4vazAfBgNVHSMEGDAWgBRembiG4Xgb2VcVWnUA
29 | UrYpDsuojDAKBggqhkjOPQQDAgNJADBGAiEAtdZ3+05CzFo90fWeZ4woeJcNQC4B
30 | 84Ill3YeZVvR8ZECIQDVRdha1xEDKuNTAManY0zthSosfXcvLnZui1A/y/DYeg==
31 | -----END CERTIFICATE-----
32 |
33 |
--------------------------------------------------------------------------------
/c2pa/keys/es256_private.key:
--------------------------------------------------------------------------------
1 | -----BEGIN PRIVATE KEY-----
2 | MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgfNJBsaRLSeHizv0m
3 | GL+gcn78QmtfLSm+n+qG9veC2W2hRANCAAQPaL6RkAkYkKU4+IryBSYxJM3h77sF
4 | iMrbvbI8fG7w2Bbl9otNG/cch3DAw5rGAPV7NWkyl3QGuV/wt0MrAPDo
5 | -----END PRIVATE KEY-----
6 |
--------------------------------------------------------------------------------
/c2pa/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 | "claim_generator": "python_trustmark/c2pa",
3 | "title": "Watermarked Image",
4 | "thumbnail": {
5 | "format": "image/jpg",
6 | "identifier": "example.jpg"
7 | },
8 | "ingredient_paths": [
9 | "example.jpg"
10 | ],
11 | "assertions": [
12 | {
13 | "label": "c2pa.soft-binding",
14 | "data": {
15 | "alg": "com.adobe.trustmark.Q",
16 | "blocks": [
17 | {
18 | "scope": {},
19 | "value": "2*00000010010100000100001111011011010011100010011101000010100000001110"
20 | }
21 | ]
22 | }
23 | },
24 | {
25 | "label": "stds.schema-org.CreativeWork",
26 | "data": {
27 | "@context": "https://schema.org",
28 | "@type": "CreativeWork",
29 | "author": [
30 | {
31 | "@type": "Person",
32 | "name": "Walter Mark"
33 | }
34 | ]
35 | }
36 | }
37 | ],
38 | "actions": [
39 | {
40 | "action": "c2pa.watermarked"
41 | }
42 | ],
43 | "alg": "es256",
44 | "ta_url": "http://timestamp.digicert.com",
45 | "private_key": "keys/es256_private.key",
46 | "sign_cert": "keys/es256_certs.pem"
47 | }
--------------------------------------------------------------------------------
/images/ghost.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ghost.png
--------------------------------------------------------------------------------
/images/ghost_P.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ghost_P.png
--------------------------------------------------------------------------------
/images/ghost_Q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ghost_Q.png
--------------------------------------------------------------------------------
/images/ripley.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ripley.jpg
--------------------------------------------------------------------------------
/images/ripley_P.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ripley_P.png
--------------------------------------------------------------------------------
/images/ripley_Q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ripley_Q.png
--------------------------------------------------------------------------------
/images/ufo_240.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ufo_240.jpg
--------------------------------------------------------------------------------
/images/ufo_240_P.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ufo_240_P.png
--------------------------------------------------------------------------------
/images/ufo_240_Q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/images/ufo_240_Q.png
--------------------------------------------------------------------------------
/js/README.md:
--------------------------------------------------------------------------------
1 | # TrustMark JavaScript example
2 |
3 | The [`js`](https://github.com/adobe/trustmark/tree/main/js) directory contains an example JavaScript implementation of decoding TrustMark watermarks embedded in images. It provides a minimal example of a client-side JavaScript application, which could be applied, for example, to a browser extension.
4 |
5 | ## Overview
6 |
7 | The example consists of key modules that handle image preprocessing, watermark detection, and data decoding. As provided, it supports the `Q` and `P` variants of the TrustMark watermarking schema.
8 |
9 | NOTE: The TrustMark JavaScript implementation only decodes watermarked images, in contrast to the Python implementation, which can both encode and decode.
10 |
11 | ## Components
12 |
13 | This example consists of a simple HTML file, `index.html` that loads two JavaScript files:
14 |
15 | - [`tm_watermark.js`](https://github.com/adobe/trustmark/blob/main/js/tm_watermark.js), the core module that handles watermark detection and decoding and defines key functions for processing images and extracting watermark data. The `modelConfigs` array specifies parameters such as the TrustMark model variant. As provided, the code checks for both Q and P variants.
16 | - [`tm_datalayer.js`](https://github.com/adobe/trustmark/blob/main/js/tm_datalayer.js) handles data decoding and schema-specific processing. It also implements error correction and interpretation of binary watermark data.
17 |
18 | If GPU compute is available (if you're using Google Chrome, check `chrome://gpu`), the code automatically uses WebGPU to process the ONNX models. WebGPU runs only in a secure context, which means on localhost or over HTTPS. You can start a local HTTPS server by running the `server.py` script with a suitable OpenSSL certificate in `server.pem`.
19 |
20 | ## Key parameters
21 |
22 | The TrustMark watermark variants to try when decoding (for example, B, C, P, or Q) are listed in the `modelConfigs` array at the top of `tm_watermark.js`; as provided, it lists the Q and P variants:
23 |
24 | ```js
25 | const modelConfigs = [
26 | { variantcode: 'Q', fname: 'decoder_Q', sessionVar: 'session_wmarkQ', resolution: 256, squarecrop: false },
27 | { variantcode: 'P', fname: 'decoder_P', sessionVar: 'session_wmarkP', resolution: 224, squarecrop: true },
28 | ];
29 | ```
30 |
31 | ## Run the example
32 |
33 | To run the example:
34 |
35 | 1. Start a local web server; for example, using Python:
36 | ```
37 | python -m http.server 8000
38 | ```
39 | 1. Open `index.html` in a browser to run the example.
40 | 1. Drag and drop images (for example from the provided [`images`](https://github.com/adobe/trustmark/tree/main/images) directory) onto the indicated area in the web page, which will display information if the image contains a TrustMark watermark (see example below).
41 |
42 | To use the code in your own project, simply include the two JavaScript files as usual.
43 |
44 | ## Example output
45 |
46 | For an image containing a TrustMark watermark:
47 |
48 | ```
49 | [4:12:48 PM] Decoding watermark...
50 | [4:12:48 PM] Watermark Found (BCH_5):
51 | 1101101111100111100111100101110001111100101100101111010000000
52 | C2PA Assertion:
53 | {
54 | "c2pa.soft-binding": {
55 | "alg": "com.adobe.trustmark.Q",
56 | "blocks": [
57 | {
58 | "scope": {},
59 | "value": "1*1101101111100111100111100101110001111100101100101111010000000"
60 | }
61 | ]
62 | }
63 | }
64 | ```
65 |
66 |
--------------------------------------------------------------------------------
/js/deps/ort-wasm-simd-threaded.jsep.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/js/deps/ort-wasm-simd-threaded.jsep.wasm
--------------------------------------------------------------------------------
/js/deps/ort-wasm-simd-threaded.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/js/deps/ort-wasm-simd-threaded.wasm
--------------------------------------------------------------------------------
/js/index.html:
--------------------------------------------------------------------------------
(Markup elided in this dump.) The page, titled "TrustMark JS / Client-side Decode Demonstrator", presents a drag-and-drop area ("Drag & drop your image here", with a "Browse Files" fallback) and an output panel ("Watermark output will appear here..."); its inline styles and scripts are not reproduced here.
--------------------------------------------------------------------------------
/js/tm_datalayer.js:
--------------------------------------------------------------------------------
1 | /*!
2 | * TrustMark Data Layer module
3 | * Copyright 2024 Adobe. All rights reserved.
4 | * Licensed under the MIT License.
5 | *
6 | * NOTICE: Adobe permits you to use, modify, and distribute this file in
7 | * accordance with the terms of the Adobe license agreement accompanying it.
8 | */
9 |
10 | // Initialize ECC engines for all supported versions
11 | let eccengine = [];
12 | for (let version = 0; version < 4; version++) {
13 | eccengine.push(DataLayer_GetECCEngine(version));
14 | }
15 |
16 | /**
17 | * Decodes the watermark data using the given ECC engine and schema.
18 | * Attempts fallback decoding with alternate schemas if the primary attempt fails.
19 | *
20 | * @param {Array} watermarkbool - The boolean array representing the watermark.
21 |  * @param {Array} eccengine - Array of ECC engines for decoding.
 * @param {string} variant - Model variant code (for example 'Q' or 'P') used in the soft binding info.
22 | * @returns {Object} Decoded watermark data with schema and soft binding info.
23 | */
24 | function DataLayer_Decode(watermarkbool, eccengine, variant) {
25 | let version = DataLayer_GetVersion(watermarkbool);
26 | let databits = DataLayer_GetSchemaDataBits(version);
27 |
28 | let data = watermarkbool.slice(0, databits);
29 | let ecc = watermarkbool.slice(databits, 96);
30 |
31 | let dataobj = BCH_Decode(eccengine[version], data, ecc);
32 | dataobj.schema = DataLayer_GetSchemaName(version);
33 |
34 | if (!dataobj.valid) {
35 | // Attempt decoding with alternate schemas
36 | for (let alt = 0; alt < 3; alt++) {
37 | if (alt === version) continue;
38 |
39 | databits = DataLayer_GetSchemaDataBits(alt);
40 | data = watermarkbool.slice(0, databits);
41 | ecc = watermarkbool.slice(databits, 96);
42 |
43 | dataobj = BCH_Decode(eccengine[alt], data, ecc);
44 | dataobj.schema = DataLayer_GetSchemaName(alt);
45 | if (dataobj.valid) break;
46 | }
47 | }
48 |
49 | // Add soft binding information
50 | dataobj.softBindingInfo = formatSoftBindingData(dataobj.data_binary, version, variant);
51 | return dataobj;
52 | }
53 |
54 | /**
55 | * Interprets the watermark data in the context of C2PA.
56 | *
57 | * @param {Object} dataobj - The decoded watermark data object.
58 | * @param {number} version - The version of the schema.
59 | * @returns {Promise} Promise resolving with the updated data object.
60 | */
61 | function interpret_C2PA(dataobj, version) {
62 | return new Promise((resolve, reject) => {
63 | if (true) { // Placeholder for schema-specific logic
64 | fetchSoftBindingInfo(dataobj.data_binary)
65 | .then(softBindingInfo => {
66 | if (softBindingInfo) {
67 | dataobj.softBindingInfo = softBindingInfo;
68 | } else {
69 | console.warn("No soft binding info found.");
70 | }
71 | resolve(dataobj);
72 | })
73 | .catch(error => {
74 | console.error("Error fetching soft binding info:", error);
75 | reject(error);
76 | });
77 | } else {
78 | resolve(dataobj);
79 | }
80 | });
81 | }
82 |
83 | /**
84 | * Extracts the schema version from the last two bits of the watermark boolean array.
85 | *
86 | * @param {Array} watermarkbool - The boolean array representing the watermark.
87 | * @returns {number} The schema version as an integer.
88 | */
89 | function DataLayer_GetVersion(watermarkbool) {
90 | watermarkbool = watermarkbool.slice(-2);
91 | return watermarkbool[0] * 2 + watermarkbool[1];
92 | }
93 |
94 | /**
95 | * Retrieves the ECC engine for the given schema version.
96 | *
97 | * @param {number} version - The schema version.
98 | * @returns {Object} The corresponding ECC engine.
99 | */
100 | function DataLayer_GetECCEngine(version) {
101 | switch (version) {
102 | case 0:
103 | return BCH(8, 137);
104 | case 1:
105 | return BCH(5, 137);
106 | case 2:
107 | return BCH(4, 137);
108 | case 3:
109 | return BCH(3, 137);
110 | default:
111 | return -1;
112 | }
113 | }
114 |
115 | /**
116 | * Retrieves the number of data bits for the given schema version.
117 | *
118 | * @param {number} version - The schema version.
119 | * @returns {number} The number of data bits.
120 | */
121 | function DataLayer_GetSchemaDataBits(version) {
122 | switch (version) {
123 | case 0:
124 | return 40;
125 | case 1:
126 | return 61;
127 | case 2:
128 | return 68;
129 | case 3:
130 | return 75;
131 | default:
132 | console.error("Invalid schema version");
133 | return 0;
134 | }
135 | }
136 |
137 | /**
138 | * Retrieves the name of the schema for the given version.
139 | *
140 | * @param {number} version - The schema version.
141 | * @returns {string} The schema name.
142 | */
143 | function DataLayer_GetSchemaName(version) {
144 | switch (version) {
145 | case 0:
146 | return "BCH_SUPER";
147 | case 1:
148 | return "BCH_5";
149 | case 2:
150 | return "BCH_4";
151 | case 3:
152 | return "BCH_3";
153 | default:
154 | return "Invalid";
155 | }
156 | }
157 |
158 | /**
159 | * Formats the encoded watermark data into a structured JSON object.
160 | *
161 | * @param {Array} encodedData - The binary data representing the watermark.
162 | * @param {number} version - The schema version.
163 | * @returns {Object|null} Formatted JSON object or null in case of errors.
164 | */
165 | function formatSoftBindingData(encodedData, version, variant) {
166 | try {
167 | const binaryString = Array.isArray(encodedData)
168 | ? encodedData.join('')
169 | : String(encodedData);
170 |
171 | return {
172 | "c2pa.soft-binding": {
173 | "alg": `com.adobe.trustmark.${variant}`,
174 | "blocks": [
175 | {
176 | "scope": {},
177 | "value": `${version}*${binaryString}`
178 | }
179 | ]
180 | }
181 | };
182 | } catch (error) {
183 | console.error("Error formatting soft binding data:", error);
184 | return null;
185 | }
186 | }
187 |
188 |
--------------------------------------------------------------------------------
/js/tm_watermark.js:
--------------------------------------------------------------------------------
1 | /*!
2 | * TrustMark JS Watermarking Decoder Module
3 | * Copyright 2024 Adobe. All rights reserved.
4 | * Licensed under the MIT License.
5 | *
6 | * NOTICE: Adobe permits you to use, modify, and distribute this file in
7 | * accordance with the terms of the Adobe license agreement accompanying it.
8 | */
9 |
10 | // Source for ONNX model binaries
11 | const MODEL_BASE_URL = "https://cc-assets.netlify.app/watermarking/trustmark-models/";
12 |
13 | // List all watermark models for decoding
14 | const modelConfigs = [
15 | { variantcode: 'Q', fname: 'decoder_Q', sessionVar: 'session_wmarkQ', resolution: 256, squarecrop: false },
16 | { variantcode: 'P', fname: 'decoder_P', sessionVar: 'session_wmarkP', resolution: 224, squarecrop: true },
17 | ];
18 |
19 | const sessions = {};
20 | let session_resize;
21 |
22 | // Load model immediately
23 | (async () => {
24 | for (const config of modelConfigs) {
25 | let startTime = new Date();
26 | try {
27 | sessions[config.sessionVar] = await ort.InferenceSession.create(`${MODEL_BASE_URL}${config.fname}.onnx`, { executionProviders: ['webgpu'] });
28 | let timeElapsed = new Date() - startTime;
29 | console.log(`${config.fname} model loaded in ${timeElapsed / 1000} seconds`);
30 | } catch (error) {
31 | console.error(`Could not load ${config.fname} watermark decoder model`, error);
32 | }
33 | }
34 | let startTime = new Date();
35 | try {
36 | session_resize = await ort.InferenceSession.create(`${MODEL_BASE_URL}resizer.onnx`, { executionProviders: ['wasm'] }); // cannot use GPU for this due to lack of antialias
37 | let timeElapsed = new Date() - startTime;
38 | console.log(`Image downscaler model loaded in ${timeElapsed / 1000} seconds`);
39 | }
40 | catch (error) {
41 | console.log('Could not load image downscaler model', error);
42 | console.log(error)
43 | }
44 | })();
45 |
46 |
47 | /* WebGPU will fail silently and intermittently if multiple concurrent inference calls are made;
48 |    this routine ensures sequential calling from multiple threads. Note this simple JS demo is single-threaded. */
49 | let inferenceLock = false;
50 |
51 | async function safeRunInference(session, feed) {
52 | while (inferenceLock) {
53 | // Wait for any ongoing inference
54 | await new Promise(resolve => setTimeout(resolve, 30));
55 | }
56 |
57 | inferenceLock = true; // Lock inference
58 | try {
59 | return await session.run(feed); // Run the inference
60 | } catch (error) {
61 | console.error("Inference error:", error); // Log any error
62 | throw error; // Rethrow for further debugging
63 | } finally {
64 | inferenceLock = false; // Unlock after inference
65 | }
66 | }
67 |
68 |
69 | /**
70 | * Converts an image URL to a tensor suitable for processing.
71 | * @param {string} imageUrl - The URL of the image to load.
72 | * @returns {Promise} The processed tensor.
73 | */
74 | async function loadImageAsTensor(imageUrl) {
75 | const img = new Image();
76 | img.src = imageUrl;
77 |
78 | return new Promise((resolve, reject) => {
79 | img.onload = () => {
80 | const canvas = document.createElement('canvas');
81 | const ctx = canvas.getContext('2d');
82 | canvas.width = img.width;
83 | canvas.height = img.height;
84 | ctx.drawImage(img, 0, 0);
85 | const imgData = ctx.getImageData(0, 0, img.width, img.height);
86 |
87 | const { data, width, height } = imgData;
88 | const totalPixels = width * height;
89 | const imageTensor = new Float32Array(totalPixels * 3);
90 |
91 | let j = 0;
92 | const page = width * height;
93 | const twopage = 2 * page;
94 |
95 | for (let i = 0; i < totalPixels; i++) {
96 | const index = i * 4;
97 | imageTensor[j] = data[index] / 255.0; // Red channel
98 | imageTensor[j + page] = data[index + 1] / 255.0; // Green channel
99 | imageTensor[j + twopage] = data[index + 2] / 255.0; // Blue channel
100 | j++;
101 | }
102 |
103 | resolve(new ort.Tensor('float32', imageTensor, [1, 3, height, width]));
104 | };
105 |
106 | img.onerror = () => reject("Failed to load image");
107 | });
108 | }
109 |
110 | /**
111 | * Computes scale factors for image resizing with precision.
112 | * @param {Array} targetDims - Target dimensions for the image.
113 | * @param {Array} inputDims - Input tensor dimensions.
114 | * @returns {Float32Array} Scale factors as a tensor.
115 | */
116 | function computeScalesFixed(targetDims, inputDims) {
117 | const [batch, channels, height, width] = inputDims;
118 | const [targetHeight, targetWidth] = targetDims;
119 |
120 | function computeScale(originalSize, targetSize) {
121 | let minScale = targetSize / originalSize;
122 | let maxScale = (targetSize + 1) / originalSize;
123 | let scale;
124 | let adjustedSize;
125 |
126 | const tolerance = 1e-12;
127 | let iterations = 0;
128 | const maxIterations = 100;
129 |
130 | while (iterations < maxIterations) {
131 | scale = (minScale + maxScale) / 2;
132 | adjustedSize = Math.floor(originalSize * scale + tolerance);
133 |
134 | if (adjustedSize < targetSize) {
135 | minScale = scale;
136 | } else if (adjustedSize > targetSize) {
137 | maxScale = scale;
138 | } else {
139 | break; // Found the correct scale
140 | }
141 |
142 | iterations++;
143 | }
144 |
145 | return scale;
146 | }
147 |
148 | const scaleH = computeScale(height, targetHeight);
149 | const scaleW = computeScale(width, targetWidth);
150 |
151 | return new Float32Array([1.0, 1.0, scaleH, scaleW]);
152 | }
153 |
154 | /**
155 | * Resizes the image tensor to a square size suitable for watermark decoding.
156 | * @param {ort.Tensor} inputTensor - The input image tensor.
157 | * @param {number} targetSize - The target size for resizing.
158 | * @returns {Promise} The resized tensor.
159 | */
160 | async function runResizeModelSquare(inputTensor, targetSize, force_square) {
161 | try {
162 | const inputDims = inputTensor.dims; // Get dimensions of the input tensor
163 | const [batch, channels, height, width] = inputDims;
164 |
165 | // Compute the aspect ratio
166 | const aspectRatio = width / height;
167 | const lscape= (aspectRatio>=1.0);
168 |
169 | let croppedTensor = inputTensor;
170 | let cropWidth = width;
171 | let cropHeight = height;
172 |
173 | // If the aspect ratio is greater than 2.0, we need to crop the center square
174 | if (lscape && (aspectRatio > 2.0 || force_square)) {
175 | cropWidth = height; // Take a square from the width
176 | const offsetX = Math.floor((width - cropWidth) / 2); // Horizontal center crop
177 | croppedTensor = await cropTensor(inputTensor, offsetX, 0, cropWidth, height);
178 | }
179 |
180 | if (!lscape && (aspectRatio < 0.5 || force_square)) {
181 | cropHeight = width; // Take a square from the height
182 | const offsetY = Math.floor((height - cropHeight) / 2); // Vertical center crop
183 | croppedTensor = await cropTensor(inputTensor, 0, offsetY, width, cropHeight);
184 | }
185 |
186 | // After cropping, resize the tensor to the target size
187 | const targetDims = [targetSize, targetSize];
188 | const scales = computeScalesFixed(targetDims, [batch, channels, cropHeight, cropWidth]);
189 | const scalesTensor = new ort.Tensor('float32', scales, [4]);
190 |
191 | // Prepare the target size tensor
192 | const targetSizeTensor = new ort.Tensor('int64', new BigInt64Array([BigInt(targetSize)]), [1]);
193 |
194 | // Set up the feeds for the model
195 | const feeds = {
196 | 'X': croppedTensor, // Cropped image tensor
197 | 'scales': scalesTensor, // Scales tensor
198 | 'target_size': targetSizeTensor // Dynamic target size tensor
199 | };
200 |
201 | const results = await session_resize.run(feeds);
202 | return results['Y'];
203 |
204 | } catch (error) {
205 | console.error('Error during resizing:', error);
206 | return null;
207 | }
208 | }
209 |
210 | // Helper function to crop the tensor
211 | async function cropTensor(inputTensor, offsetX, offsetY, cropWidth, cropHeight) {
212 | const [batch, channels, height, width] = inputTensor.dims;
213 | const croppedData = new Float32Array(batch * channels * cropWidth * cropHeight);
214 | const inputData = inputTensor.data;
215 |
216 | let k = 0;
217 | for (let c = 0; c < channels; c++) {
218 | for (let y = 0; y < cropHeight; y++) {
219 | for (let x = 0; x < cropWidth; x++) {
220 | const srcIndex = c * width * height + (y + offsetY) * width + (x + offsetX);
221 | croppedData[k++] = inputData[srcIndex];
222 | }
223 | }
224 | }
225 |
226 | return new ort.Tensor('float32', croppedData, [batch, channels, cropHeight, cropWidth]);
227 | }
228 |
229 |
230 | /**
231 | * Decodes the watermark from the processed image tensor.
232 | * @param {string} base64Image - Base64 representation of the image.
233 | * @returns {Promise} Decoded watermark data.
234 | */
235 | async function runwmark(base64Image) {
236 |
237 | let watermarks=[]
238 | let watermarks_present=[];
239 | try {
240 |
241 | const inputTensor = await loadImageAsTensor(base64Image);
242 |
243 | for (const config of modelConfigs) {
244 |
245 | const session = sessions[config.sessionVar];
246 | if (!session) {
247 | console.error(`Session for ${config.fname} not loaded, skipping.`);
248 | continue;
249 | }
250 |
251 |
252 | const resizedTensorWM = await runResizeModelSquare(inputTensor, config.resolution, config.squarecrop);
253 | if (!resizedTensorWM) throw new Error("Failed to resize tensor for watermark detection.");
254 |
255 | const feeds = { image: resizedTensorWM };
256 |
257 | let startTime = new Date();
258 | const results = await safeRunInference(session, feeds);
259 |
260 | const watermarkFloat = results['output']['cpuData'];
261 | const watermarkBool = watermarkFloat.map((v) => v >= 0);
262 |
263 | const dataObj = DataLayer_Decode(watermarkBool, eccengine, config.variantcode);
264 | console.log(`Watermark model inference in ${(new Date() - startTime)} milliseconds`);
265 |
266 | // Append results to arrays
267 | watermarks.push(dataObj);
268 | watermarks_present.push(dataObj.valid);
269 | }
270 |
271 | } catch (error) {
272 | console.error("Error in watermark decoding:", error);
273 | return { watermark_present: false, watermark: null, schema: null };
274 | }
275 |
276 | // Return the first detected watermark (if several were found)
277 | const firstValidIndex = watermarks_present.findIndex(isValid => isValid === true);
278 | let watermark;
279 | let watermark_present=false;
280 | if (firstValidIndex !== -1) {
281 | watermark = watermarks[firstValidIndex];
282 |
283 | return {
284 | watermark_present: watermark.valid,
285 | watermark: watermark.valid ? watermark.data_binary : null,
286 | schema: watermark.schema,
287 | c2padata: watermark.softBindingInfo,
288 | };
289 | }
290 | else {
291 | return {
292 | watermark_present: false,
293 | }
294 | }
295 |
296 |
297 | }
298 |
299 |
--------------------------------------------------------------------------------
/python/CONFIG.md:
--------------------------------------------------------------------------------
1 | # Configuring TrustMark
2 |
3 | ## Overview
4 |
5 | All watermarking algorithms trade off between three properties:
6 |
7 | - **Capacity (bits)**
8 | - **Robustness (to various transformations)**
9 | - **Visibility (of watermark)**
10 |
11 | This document explains how to configure TrustMark to tune these properties; however, the default configuration for TrustMark (variant Q, 100% strength, BCH_5 error correction) is sufficient for most use cases.
12 |
13 | ## Model variant
14 |
15 | TrustMark has four model variants (**B**, **C**, **P**, and **Q**) that may be selected when instantiating TrustMark. All encode/decode calls on the object will use this variant.
16 |
17 | In general, we recommend using **P** or **Q**:
18 | - **P** is useful for creative applications where very high visual quality is required.
19 | - **Q** is a good all-rounder and is the default.
20 |
21 | > **Note:** Images encoded with one model variant cannot be decoded with another.
22 |
23 | | Variant | Typical PSNR | Model Size Enc/Dec (MB) | Description |
24 | |---------|--------------|-------------------------|-----------------------------------------------------------------------------------------------------|
25 | | **Q** | 43-45 | 17/45 | Default (**Q**uality). Good trade-off between robustness and imperceptibility. Uses ResNet-50 decoder. |
26 | | **B** | 43-45 | 17/45 | (**B**eta). Very similar to Q, included mainly for reproducing the paper. Uses ResNet-50 decoder. |
27 | | **C** | 38-39 | 17/21 | (**C**ompact). Uses a ResNet-18 decoder (smaller model size). Slightly lower visual quality. |
28 | | **P** | 48-50 | 16/45 | (**P**erceptual). Very high visual quality and good robustness. ResNet-50 decoder trained with much higher weight on perceptual loss (see paper). |
29 |
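For example, to instantiate the **P** variant for a high-visual-quality use case (every subsequent encode/decode call on this object then uses P):

```python
from trustmark import TrustMark

tm = TrustMark(verbose=True, model_type='P')  # or 'B', 'C', 'Q'
```
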
30 | ## Watermark strength
31 |
32 | Set the optional `WM_STRENGTH` parameter when encoding (at runtime).
33 | Its default value is `1.0`, and changing it provides a trade-off between **robustness** and **visibility**:
34 |
35 | - Raising its value (for example, to 1.5) improves robustness (so, for example, the watermark survives printing) but increases the likelihood of ripple artifacts.
36 | - Lowering its value (for example, to 0.8) reduces any likelihood of artifacts but compromises on robustness; however it still survives lower noise, screenshotting, or social media.
37 |
38 | For example:
39 |
40 | ```python
41 | encoded_image = tm.encode(input_image, payload="example", WM_STRENGTH=1.5)
42 | ```
43 |
44 | ## Error correction level
45 |
46 | TrustMark encodes a payload (the watermark data embedded within the image) of 100 bits.
47 | The data schema implemented in `python/datalayer.py` enables you to choose an error correction level over the raw 100 bits of payload to maintain reliability under transformations or noise.
48 |
49 | ### Encoding modes
50 |
51 | Set the error correction level using one of the four encoding modes.
52 |
53 | The following table describes TrustMark's encoding modes:
54 |
55 | | Encoding | Protected payload | Number of bit flips allowed |
56 | |----------|-------------------|-----------------------------|
57 | | `Encoding.BCH_5` | 61 bits (+ 35 ECC bits) | 5 |
58 | | `Encoding.BCH_4` | 68 bits (+ 28 ECC bits) | 4 |
59 | | `Encoding.BCH_3` | 75 bits (+ 21 ECC bits) | 3 |
60 | | `Encoding.BCH_SUPER` | 40 bits (+ 56 ECC bits) | 8 |
61 |
62 | Specify the mode when you instantiate the encoder, as follows:
63 |
64 | ```py
65 | tm=TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.<MODE>)
66 | ```
67 |
68 | Where `<MODE>` is `BCH_5`, `BCH_4`, `BCH_3`, or `BCH_SUPER`.
69 |
70 | For example:
71 |
72 | ```py
73 | tm=TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.BCH_5)
74 | ```
75 |
76 | The decoder automatically detects the data schema in a watermark, so you can choose the level of robustness that best suits your use case.
77 |
78 | Selecting the model and strength implicitly selects the robustness and visibility of the watermark. If you have reduced robustness to lower visibility, you can regain some robustness by increasing the error correction level (at the cost of payload capacity). Note that even 40 bits gives a key space of about 2^40 ≈ 1.1 trillion values.
79 |
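 | For example, a minimal sketch (reusing calls from `python/test.py`) that queries the usable capacity of the selected mode and fills it exactly with a binary payload:
 |
 | ```py
 | import random
 | from PIL import Image
 | from trustmark import TrustMark
 |
 | tm = TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.BCH_SUPER)
 |
 | cover = Image.open('images/ufo_240.jpg').convert('RGB')
 | capacity = tm.schemaCapacity()  # 40 usable payload bits for BCH_SUPER (see table above)
 | payload = ''.join(random.choice('01') for _ in range(capacity))
 | tm.encode(cover, payload, MODE='binary').save('ufo_240_wm.png')
 | ```
 |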
80 | ## Center cropping
81 |
82 | TrustMark generates residuals at 256 x 256 and then scales/blends them into the original image. Several derivative papers have adopted this universal resolution-scaling trick.
83 |
84 | - If the original image is extremely long/thin (aspect ratio beyond 2:1), the residual watermark will degrade when scaled.
85 | - TrustMark addresses this by automatically center-cropping the image to a square if the aspect ratio exceeds 2.0. For example, for a 1000 x 200 image, only a 200 x 200 region in the center carries the watermark.
86 | - The aspect ratio limit can be overridden via the `ASPECT_RATIO_LIM` parameter. Setting it to 1.0 forces center-crop behavior for every image, which is useful for content platforms that square-crop images; this is also the default when using model variant **P** (see the sketch below).
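 |
 | As a sketch only (this assumes `ASPECT_RATIO_LIM` is passed at encode time like `WM_STRENGTH`; check the `encode` signature in `python/trustmark/trustmark.py` for its authoritative placement):
 |
 | ```py
 | # Hypothetical call: force center-crop behavior for every image
 | encoded_image = tm.encode(input_image, payload="example", ASPECT_RATIO_LIM=1.0)
 | ```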
87 |
88 | ## Do not concentrate watermarks
89 |
90 | Visual quality is often measured via **PSNR** (Peak Signal-to-Noise Ratio), but PSNR does not perfectly correlate with human perception of watermark visibility.
91 |
92 | Some derivative works have tried to improve PSNR by "zero padding" or concentrating the watermark into only the central region, effectively altering fewer pixels and artificially raising the PSNR score. However, this approach **can increase human-visible artifacts** in the concentrated area.
93 |
94 | **Parameter:** `CONCENTRATE_WM_REGION`
95 | - The default is 100% (the watermark covers the whole image; no zero padding).
96 | - Concentrating the watermark into, for example, the central 80% of the image (zero-padding the rest) might inflate PSNR by ~5 dB.
97 | - Concentrating it into the central 50% might inflate PSNR by ~10 dB.
98 | - In extreme cases, PSNR can reach 55–60 dB, but at the cost of noticeable artifacts in the smaller watermarked region.
99 |
100 | **In summary**: While this functionality exists for comparison purposes, it is not recommended for production. A high concentration setting yields a high PSNR but, paradoxically, more visible artifacts in the watermarked area.
101 |
--------------------------------------------------------------------------------
/python/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2023 Adobe
2 | All Rights Reserved.
3 |
4 | NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | accordance with the terms of the license agreement accompanying it.
6 |
7 | MIT License
8 |
9 | Permission is hereby granted, free of charge, to any person obtaining a copy
10 | of this software and associated documentation files (the "Software"), to deal
11 | in the Software without restriction, including without limitation the rights
12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | copies of the Software, and to permit persons to whom the Software is
14 | furnished to do so, subject to the following conditions:
15 |
16 | The above copyright notice and this permission notice shall be included in all
17 | copies or substantial portions of the Software.
18 |
19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | SOFTWARE.
26 |
--------------------------------------------------------------------------------
/python/README.md:
--------------------------------------------------------------------------------
1 | # TrustMark
2 |
3 | This repository contains the official, open source implementation of TrustMark watermarking for the Content Authenticity Initiative (CAI) as described in [**TrustMark - Universal Watermarking for Arbitrary Resolution Images**](https://arxiv.org/abs/2311.18297) (`arXiv:2311.18297`) by [Tu Bui](https://www.surrey.ac.uk/people/tu-bui)[^1], [Shruti Agarwal](https://research.adobe.com/person/shruti-agarwal/)[^2], and [John Collomosse](https://www.collomosse.com)[^1] [^2].
4 |
5 | [^1]: [DECaDE](https://decade.ac.uk/) Centre for the Decentralized Digital Economy, University of Surrey, UK.
6 |
7 | [^2]: [Adobe Research](https://research.adobe.com/), San Jose, CA.
8 |
9 | ## Overview
10 |
11 | This repository contains the following directories:
12 |
13 | - `/python`: Python implementation of TrustMark for encoding, decoding and removing image watermarks (using PyTorch). For information on configuring TrustMark in Python, see [Configuring TrustMark](CONFIG.md).
14 | - `/js`: JavaScript implementation of TrustMark decoding of image watermarks (using ONNX). For more information, see [TrustMark - JavaScript implementation](../js/README.md).
15 | - `/c2pa`: Python example of how to indicate the presence of a TrustMark watermark in a C2PA manifest. For more information, see [Using TrustMark with C2PA](../c2pa/README.md).
16 |
17 | Model files (a **ckpt** PyTorch file for Python and an **onnx** ONNX file for JavaScript) are not packaged in this repository due to their size; they are downloaded automatically on first use. See the [URLs and MD5 hashes in the code](https://github.com/adobe/trustmark/blob/4ef0dde4abd84d1c6873e7c5024482f849db2c73/python/trustmark/trustmark.py#L30) if you need a direct download link.
18 |
19 | More information:
20 | - For answers to common questions, see the [FAQ](../FAQ.md).
21 | - For information on configuring TrustMark in Python, see [Configuring TrustMark](CONFIG.md).
22 |
23 | ## Installation
24 |
25 | ### Prerequisite
26 |
27 | You must have Python 3.8.5 or higher to use the TrustMark Python implementation.
28 |
29 | ### Installing from PyPI
30 |
31 | The easiest way to install TrustMark is from the [Python Package Index (PyPI)](https://pypi.org/project/trustmark/) by entering this command:
32 |
33 | ```
34 | pip install trustmark
35 | ```
36 |
37 | Alternatively, after you've cloned the repository, you can install from the `python` directory:
38 |
39 | ```
40 | cd trustmark/python
41 | pip install .
42 | ```
43 |
44 | ## Quickstart
45 |
46 | To get started quickly, run the `python/test.py` script that provides examples of watermarking several
47 | image files from the `images` directory.
48 |
49 | ### Run the example
50 |
51 | Run the example as follows:
52 |
53 | ```sh
54 | cd trustmark/python
55 | python test.py
56 | ```
57 |
58 | You'll see output like this:
59 |
60 | ```
61 | Initializing TrustMark watermarking with ECC using [cpu]
62 | Extracted secret: 1000000100001110000010010001011110010001011000100000100110110 (schema 1)
63 | PSNR = 50.357909
64 | No secret after removal
65 | ```
66 |
67 | ### Example script
68 |
69 | The `python/test.py` script provides examples of watermarking a JPEG photo, a JPEG GenAI image, and an RGBA PNG image. The simplified example below uses TrustMark variant Q to encode the string `mysecret` (7-bit ASCII) into the image `ufo_240.jpg`; the watermark is then decoded, and finally removed from the image.
70 |
71 | ```python
72 | from trustmark import TrustMark
73 | from PIL import Image
74 |
75 | # init
76 | tm=TrustMark(verbose=True, model_type='Q') # or try P
77 |
78 | # encoding example
79 | cover = Image.open('images/ufo_240.jpg').convert('RGB')
80 | tm.encode(cover, 'mysecret').save('ufo_240_Q.png')
81 |
82 | # decoding example
83 | cover = Image.open('images/ufo_240_Q.png').convert('RGB')
84 | wm_secret, wm_present, wm_schema = tm.decode(cover)
85 |
86 | if wm_present:
87 | print(f'Extracted secret: {wm_secret}')
88 | else:
89 | print('No watermark detected')
90 |
91 | # removal example
92 | stego = Image.open('images/ufo_240_Q.png').convert('RGB')
93 | im_recover = tm.remove_watermark(stego)
94 | im_recover.save('images/recovered.png')
95 | ```
96 |
97 | ## GPU setup
98 |
99 | TrustMark runs well on CPU hardware.
100 |
101 | To leverage GPU compute for the PyTorch implementation on Ubuntu Linux, first install Conda, then create an environment and install TrustMark with CUDA-enabled PyTorch (CUDA 12.1 wheels shown):
102 |
103 | ```sh
104 | conda create --name trustmark python=3.10
105 | conda activate trustmark
106 | pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121
108 | pip install .
109 | ```
110 |
111 | For the JavaScript implementation, Chromium-based browsers automatically use WebGPU when it is available.
112 |
113 | ## Data schema
114 |
115 | TrustMark encodes a payload (the watermark data embedded within the image) of 100 bits.
116 | You can configure an error correction level over the raw 100 bits of payload to maintain reliability under transformations or noise.
117 |
118 | In payload encoding, the version bits comprise two reserved (unused) bits and two bits encoding an integer value 0-3 that specifies the data schema, as follows:
119 | - 0: BCH_SUPER
120 | - 1: BCH_5
121 | - 2: BCH_4
122 | - 3: BCH_3
123 |
124 | For more details and information on configuring the encoding mode in Python, see [Configuring TrustMark](CONFIG.md).
125 |
126 | ## Citation
127 |
128 | If you find this work useful, please cite the repository and/or TrustMark paper as follows:
129 |
130 | ```
131 | @article{trustmark,
132 | title={Trustmark: Universal Watermarking for Arbitrary Resolution Images},
133 | author={Bui, Tu and Agarwal, Shruti and Collomosse, John},
134 | journal = {ArXiv e-prints},
135 | archivePrefix = "arXiv",
136 | eprint = {2311.18297},
137 | year = 2023,
138 | month = nov
139 | }
140 | ```
141 |
142 | ## License
143 |
144 | This package is distributed under the terms of the [MIT license](https://github.com/adobe/trustmark/blob/main/LICENSE).
145 |
--------------------------------------------------------------------------------
/python/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "trustmark"
3 | version = "0.8.0"
4 | description = "TrustMark: Universal Watermarking for Arbitrary Resolution Images"
5 | readme = "README.md"
6 | requires-python = ">=3.8.5"
7 |
8 | license = {file = "LICENSE"}
9 |
10 | keywords = ["watermarking", "CAI", "watermark", "Content Authenticity Initiative", "provenance"]
11 |
12 | authors = [
13 | {name = "Shruti Agarwal", email = "shragarw@adobe.com" },
14 | {name = "John Collomosse", email = "collomos@adobe.com" }
15 | ]
16 |
17 | maintainers = [
18 | {name = "John Collomosse", email = "collomos@adobe.com" }
19 | ]
20 |
21 |
22 | classifiers = [
23 | "Development Status :: 5 - Production/Stable",
24 |
25 | # Indicate who your project is intended for
26 | "Intended Audience :: Science/Research",
27 |
28 | "License :: OSI Approved :: MIT License",
29 |
30 | "Programming Language :: Python :: 3"
31 | ]
32 |
33 | dependencies = [
34 | "omegaconf>=2.1",
35 | "pathlib>=1.0.1",
36 | "numpy>=1.20.0,<2.0.0",
37 | "torch>=2.1.2",
38 | "torchvision>=0.16.2",
39 | "lightning>=2.0",
40 | "six>=1.9",
41 | "einops>=0.4.0"
42 | ]
43 |
44 | [build-system]
45 | requires = ["setuptools>=43.0.0", "wheel"]
46 | build-backend = "setuptools.build_meta"
47 |
48 | [project.urls]
49 | Repository = "https://github.com/adobe/trustmark.git"
50 |
51 |
--------------------------------------------------------------------------------
/python/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.20.0,<2.0.0
2 | pathlib>=1.0.1
3 | torch>=2.1.2
4 | torchvision>=0.16.2
5 | lightning>=2.0
6 | omegaconf>=2.1
7 | six>=1.9
8 | einops>=0.4.0
9 |
--------------------------------------------------------------------------------
/python/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe
2 | # All Rights Reserved.
3 |
4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | # accordance with the terms of the Adobe license agreement accompanying
6 | # it.
7 |
8 | from setuptools import setup, find_packages
9 |
10 | with open("README.md", "r") as f:
11 | long_description = f.read()
12 |
13 | setup(name='trustmark',
14 | version='0.8.0',
15 | python_requires='>=3.8.5',
16 | description='High fidelity image watermarking for the Content Authenticity Initiative (CAI)',
17 | url='https://github.com/adobe/trustmark',
18 | long_description=long_description,
19 | long_description_content_type="text/markdown",
20 | author='John Collomosse',
21 | author_email='collomos@adobe.com',
22 | license='MIT License',
23 | packages=['trustmark','trustmark.KBNet'],
24 | package_data={'trustmark': ['**/*.yaml','**/*.ckpt','**/*.md']},
25 | include_package_data = True,
26 | install_requires=['omegaconf>=2.1',
27 | 'pathlib>=1.0.1',
28 | 'numpy>=1.20.0,<2.0.0',
29 | 'torch>=2.1.2',
30 | 'torchvision>=0.16.2',
31 | 'lightning>=2.0',
32 | 'six>=1.9',
33 | 'einops>=0.4.0'
34 | ],
35 |
36 | classifiers=[
37 | 'Development Status :: 5 - Production/Stable',
38 | 'Intended Audience :: Science/Research',
39 | 'Operating System :: OS Independent',
40 | 'License :: OSI Approved :: MIT License',
41 | 'Programming Language :: Python :: 3',],)
42 |
--------------------------------------------------------------------------------
/python/test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Adobe
2 | # All Rights Reserved.
3 |
4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | # accordance with the terms of the Adobe license agreement accompanying
6 | # it.
7 |
8 |
9 | from trustmark import TrustMark
10 | from PIL import Image
11 | from pathlib import Path
12 | import math,random
13 | import numpy as np
14 |
15 |
16 | #EXAMPLE_FILE = '../images/ufo_240.jpg' # JPEG example
17 | #EXAMPLE_FILE = '../images/ghost.png' # PNG RGBA example
18 | EXAMPLE_FILE = '../images/ripley.jpg' # JPEG example
19 |
20 | # Available variants: Q=balanced, P=high visual quality, C=compact decoder, B=base model from the paper
21 | VARIANT='P'
22 | tm=TrustMark(verbose=True, model_type=VARIANT, encoding_type=TrustMark.Encoding.BCH_5)
23 |
24 | # encoding example
25 | cover = Image.open(EXAMPLE_FILE)
26 | rgb = cover.convert('RGB')
27 | has_alpha = cover.mode == 'RGBA'
28 | if has_alpha:
29 | alpha = cover.split()[-1]
30 |
31 | random.seed(1234)
32 | capacity=tm.schemaCapacity()
33 | bitstring=''.join([random.choice(['0', '1']) for _ in range(capacity)])
34 | encoded=tm.encode(rgb, bitstring, MODE='binary')
35 |
36 | if has_alpha:
37 | encoded.putalpha(alpha)
38 | outfile = Path(EXAMPLE_FILE).stem + '_' + VARIANT + '.png'
39 | encoded.save(outfile, exif=cover.info.get('exif'), icc_profile=cover.info.get('icc_profile'), dpi=cover.info.get('dpi'))
40 |
41 | # decoding example
42 | stego = Image.open(outfile).convert('RGB')
43 | wm_secret, wm_present, wm_schema = tm.decode(stego, MODE='binary')
44 | if wm_present:
45 | print(f'Extracted secret: {wm_secret} (schema {wm_schema})')
46 | else:
47 | print('No watermark detected')
48 |
49 | # psnr (quality, higher is better)
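 | # PSNR = 10*log10(MAX^2 / MSE); computed below as 20*log10(MAX) - 10*log10(MSE)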
50 | mse = np.mean(np.square(np.subtract(np.asarray(stego).astype(np.int16), np.asarray(rgb).astype(np.int16))))
51 | if mse > 0:
52 | PIXEL_MAX = 255.0
53 | psnr = 20 * math.log10(PIXEL_MAX) - 10 * math.log10(mse)
54 | print('PSNR = %f' % psnr)
55 |
56 | # removal
57 | stego = Image.open(outfile)
58 | rgb = stego.convert('RGB')
59 | has_alpha = stego.mode == 'RGBA'
60 | if has_alpha:
61 | alpha = stego.split()[-1]
62 | im_recover = tm.remove_watermark(rgb)
63 | wm_secret, wm_present, wm_schema = tm.decode(im_recover)
64 | if wm_present:
65 | print(f'Extracted secret: {wm_secret} (schema {wm_schema})')
66 | else:
67 | print('No secret after removal')
68 | if has_alpha:
69 | im_recover.putalpha(alpha)
70 | im_recover.save('recovered.png', exif=stego.info.get('exif'), icc_profile=stego.info.get('icc_profile'), dpi=stego.info.get('dpi'))
71 |
72 |
--------------------------------------------------------------------------------
/python/trustmark/KBNet/README.md:
--------------------------------------------------------------------------------
1 | Adapted from https://github.com/zhangyi-3/KBNet
2 |
3 | MIT License
4 |
5 | Copyright (c) 2022 Zhang Yi
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
--------------------------------------------------------------------------------
/python/trustmark/KBNet/arch_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe
2 | # All Rights Reserved.
3 |
4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | # accordance with the terms of the Adobe license agreement accompanying
6 | # it.
7 |
8 | import math
9 | import torch
10 | from torch import nn as nn
11 | from torch.nn import functional as F
12 | from torch.nn import init as init
13 | from torch.nn.modules.batchnorm import _BatchNorm
14 |
15 |
16 | @torch.no_grad()
17 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
18 | """Initialize network weights.
19 |
20 | Args:
21 | module_list (list[nn.Module] | nn.Module): Modules to be initialized.
22 | scale (float): Scale initialized weights, especially for residual
23 | blocks. Default: 1.
24 | bias_fill (float): The value to fill bias. Default: 0
25 | kwargs (dict): Other arguments for initialization function.
26 | """
27 | if not isinstance(module_list, list):
28 | module_list = [module_list]
29 | for module in module_list:
30 | for m in module.modules():
31 | if isinstance(m, nn.Conv2d):
32 | init.kaiming_normal_(m.weight, **kwargs)
33 | m.weight.data *= scale
34 | if m.bias is not None:
35 | m.bias.data.fill_(bias_fill)
36 | elif isinstance(m, nn.Linear):
37 | init.kaiming_normal_(m.weight, **kwargs)
38 | m.weight.data *= scale
39 | if m.bias is not None:
40 | m.bias.data.fill_(bias_fill)
41 | elif isinstance(m, _BatchNorm):
42 | init.constant_(m.weight, 1)
43 | if m.bias is not None:
44 | m.bias.data.fill_(bias_fill)
45 |
46 |
47 | def make_layer(basic_block, num_basic_block, **kwarg):
48 | """Make layers by stacking the same blocks.
49 |
50 | Args:
51 | basic_block (nn.module): nn.module class for basic block.
52 | num_basic_block (int): number of blocks.
53 |
54 | Returns:
55 | nn.Sequential: Stacked blocks in nn.Sequential.
56 | """
57 | layers = []
58 | for _ in range(num_basic_block):
59 | layers.append(basic_block(**kwarg))
60 | return nn.Sequential(*layers)
61 |
62 |
63 | class ResidualBlockNoBN(nn.Module):
64 | """Residual block without BN.
65 |
66 | It has a style of:
67 | ---Conv-ReLU-Conv-+-
68 | |________________|
69 |
70 | Args:
71 | num_feat (int): Channel number of intermediate features.
72 | Default: 64.
73 | res_scale (float): Residual scale. Default: 1.
74 | pytorch_init (bool): If set to True, use pytorch default init,
75 | otherwise, use default_init_weights. Default: False.
76 | """
77 |
78 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
79 | super(ResidualBlockNoBN, self).__init__()
80 | self.res_scale = res_scale
81 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
82 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
83 | self.relu = nn.ReLU(inplace=True)
84 |
85 | if not pytorch_init:
86 | default_init_weights([self.conv1, self.conv2], 0.1)
87 |
88 | def forward(self, x):
89 | identity = x
90 | out = self.conv2(self.relu(self.conv1(x)))
91 | return identity + out * self.res_scale
92 |
93 |
94 | class Upsample(nn.Sequential):
95 | """Upsample module.
96 |
97 | Args:
98 | scale (int): Scale factor. Supported scales: 2^n and 3.
99 | num_feat (int): Channel number of intermediate features.
100 | """
101 |
102 | def __init__(self, scale, num_feat):
103 | m = []
104 | if (scale & (scale - 1)) == 0: # scale = 2^n
105 | for _ in range(int(math.log(scale, 2))):
106 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
107 | m.append(nn.PixelShuffle(2))
108 | elif scale == 3:
109 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
110 | m.append(nn.PixelShuffle(3))
111 | else:
112 | raise ValueError(f'scale {scale} is not supported. '
113 | 'Supported scales: 2^n and 3.')
114 | super(Upsample, self).__init__(*m)
115 |
116 |
117 | def flow_warp(x,
118 | flow,
119 | interp_mode='bilinear',
120 | padding_mode='zeros',
121 | align_corners=True):
122 | """Warp an image or feature map with optical flow.
123 |
124 | Args:
125 | x (Tensor): Tensor with size (n, c, h, w).
126 | flow (Tensor): Tensor with size (n, h, w, 2), normal value.
127 | interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
128 | padding_mode (str): 'zeros' or 'border' or 'reflection'.
129 | Default: 'zeros'.
130 | align_corners (bool): Before pytorch 1.3, the default value is
131 | align_corners=True. After pytorch 1.3, the default value is
132 | align_corners=False. Here, we use True as the default.
133 |
134 | Returns:
135 | Tensor: Warped image or feature map.
136 | """
137 | assert x.size()[-2:] == flow.size()[1:3]
138 | _, _, h, w = x.size()
139 | # create mesh grid
140 | grid_y, grid_x = torch.meshgrid(
141 | torch.arange(0, h).type_as(x),
142 | torch.arange(0, w).type_as(x))
143 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
144 | grid.requires_grad = False
145 |
146 | vgrid = grid + flow
147 | # scale grid to [-1,1]
148 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
149 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
150 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
151 | output = F.grid_sample(
152 | x,
153 | vgrid_scaled,
154 | mode=interp_mode,
155 | padding_mode=padding_mode,
156 | align_corners=align_corners)
157 |
158 | # TODO, what if align_corners=False
159 | return output
160 |
161 |
162 | def resize_flow(flow,
163 | size_type,
164 | sizes,
165 | interp_mode='bilinear',
166 | align_corners=False):
167 | """Resize a flow according to ratio or shape.
168 |
169 | Args:
170 | flow (Tensor): Precomputed flow. shape [N, 2, H, W].
171 | size_type (str): 'ratio' or 'shape'.
172 | sizes (list[int | float]): the ratio for resizing or the final output
173 | shape.
174 | 1) The order of ratio should be [ratio_h, ratio_w]. For
175 | downsampling, the ratio should be smaller than 1.0 (i.e., ratio
176 | < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
177 | ratio > 1.0).
178 | 2) The order of output_size should be [out_h, out_w].
179 | interp_mode (str): The mode of interpolation for resizing.
180 | Default: 'bilinear'.
181 | align_corners (bool): Whether align corners. Default: False.
182 |
183 | Returns:
184 | Tensor: Resized flow.
185 | """
186 | _, _, flow_h, flow_w = flow.size()
187 | if size_type == 'ratio':
188 | output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
189 | elif size_type == 'shape':
190 | output_h, output_w = sizes[0], sizes[1]
191 | else:
192 | raise ValueError(
193 | f'Size type should be ratio or shape, but got type {size_type}.')
194 |
195 | input_flow = flow.clone()
196 | ratio_h = output_h / flow_h
197 | ratio_w = output_w / flow_w
198 | input_flow[:, 0, :, :] *= ratio_w
199 | input_flow[:, 1, :, :] *= ratio_h
200 | resized_flow = F.interpolate(
201 | input=input_flow,
202 | size=(output_h, output_w),
203 | mode=interp_mode,
204 | align_corners=align_corners)
205 | return resized_flow
206 |
207 |
208 | # TODO: may write a cpp file
209 | def pixel_unshuffle(x, scale):
210 | """ Pixel unshuffle.
211 |
212 | Args:
213 | x (Tensor): Input feature with shape (b, c, hh, hw).
214 | scale (int): Downsample ratio.
215 |
216 | Returns:
217 | Tensor: the pixel unshuffled feature.
218 | """
219 | b, c, hh, hw = x.size()
220 | out_channel = c * (scale ** 2)
221 | assert hh % scale == 0 and hw % scale == 0
222 | h = hh // scale
223 | w = hw // scale
224 | x_view = x.view(b, c, h, scale, w, scale)
225 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
226 |
227 |
--------------------------------------------------------------------------------
/python/trustmark/KBNet/kb_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe
2 | # All Rights Reserved.
3 |
4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | # accordance with the terms of the Adobe license agreement accompanying
6 | # it.
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class LayerNormFunction(torch.autograd.Function):
14 | @staticmethod
15 | def forward(ctx, x, weight, bias, eps):
16 | ctx.eps = eps
17 | N, C, H, W = x.size()
18 | mu = x.mean(1, keepdim=True)
19 | var = (x - mu).pow(2).mean(1, keepdim=True)
20 | # print('mu, var', mu.mean(), var.mean())
21 | # d.append([mu.mean(), var.mean()])
22 | y = (x - mu) / (var + eps).sqrt()
23 | weight, bias, y = weight.contiguous(), bias.contiguous(), y.contiguous() # avoid cuda error
24 | ctx.save_for_backward(y, var, weight)
25 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
26 | return y
27 |
28 | @staticmethod
29 | def backward(ctx, grad_output):
30 | eps = ctx.eps
31 |
32 | N, C, H, W = grad_output.size()
33 | # y, var, weight = ctx.saved_variables
34 | y, var, weight = ctx.saved_tensors
35 | g = grad_output * weight.view(1, C, 1, 1)
36 | mean_g = g.mean(dim=1, keepdim=True)
37 |
38 | mean_gy = (g * y).mean(dim=1, keepdim=True)
39 | gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
40 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
41 | dim=0), None
42 |
43 |
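 | # Channel-wise LayerNorm for NCHW feature maps, backed by the custom autograd
 | # Function above for a memory-efficient backward pass.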
44 | class LayerNorm2d(nn.Module):
45 |
46 | def __init__(self, channels, eps=1e-6, requires_grad=True):
47 | super(LayerNorm2d, self).__init__()
48 | self.register_parameter('weight', nn.Parameter(torch.ones(channels), requires_grad=requires_grad))
49 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels), requires_grad=requires_grad))
50 | self.eps = eps
51 |
52 | def forward(self, x):
53 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
54 |
55 |
56 | class SimpleGate(nn.Module):
57 | def forward(self, x):
58 | x1, x2 = x.chunk(2, dim=1)
59 | return x1 * x2
60 |
61 |
62 | class KBAFunction(torch.autograd.Function):
63 |
64 | @staticmethod
65 | def forward(ctx, x, att, selfk, selfg, selfb, selfw):
66 | B, nset, H, W = att.shape
67 | KK = selfk ** 2
68 | selfc = x.shape[1]
69 |
70 | att = att.reshape(B, nset, H * W).transpose(-2, -1)
71 |
72 | ctx.selfk, ctx.selfg, ctx.selfc, ctx.KK, ctx.nset = selfk, selfg, selfc, KK, nset
73 | ctx.x, ctx.att, ctx.selfb, ctx.selfw = x, att, selfb, selfw
74 |
75 | bias = att @ selfb
76 | attk = att @ selfw
77 |
78 | uf = torch.nn.functional.unfold(x, kernel_size=selfk, padding=selfk // 2)
79 |
80 | # for unfold att / less memory cost
81 | uf = uf.reshape(B, selfg, selfc // selfg * KK, H * W).permute(0, 3, 1, 2)
82 | attk = attk.reshape(B, H * W, selfg, selfc // selfg, selfc // selfg * KK)
83 |
84 | x = attk @ uf.unsqueeze(-1) #
85 | del attk, uf
86 | x = x.squeeze(-1).reshape(B, H * W, selfc) + bias
87 | x = x.transpose(-1, -2).reshape(B, selfc, H, W)
88 | return x
89 |
90 | @staticmethod
91 | def backward(ctx, grad_output):
92 | x, att, selfb, selfw = ctx.x, ctx.att, ctx.selfb, ctx.selfw
93 | selfk, selfg, selfc, KK, nset = ctx.selfk, ctx.selfg, ctx.selfc, ctx.KK, ctx.nset
94 |
95 | B, selfc, H, W = grad_output.size()
96 |
97 | dbias = grad_output.reshape(B, selfc, H * W).transpose(-1, -2)
98 |
99 | dselfb = att.transpose(-2, -1) @ dbias
100 | datt = dbias @ selfb.transpose(-2, -1)
101 |
102 | attk = att @ selfw
103 | uf = F.unfold(x, kernel_size=selfk, padding=selfk // 2)
104 | # for unfold att / less memory cost
105 | uf = uf.reshape(B, selfg, selfc // selfg * KK, H * W).permute(0, 3, 1, 2)
106 | attk = attk.reshape(B, H * W, selfg, selfc // selfg, selfc // selfg * KK)
107 |
108 | dx = dbias.view(B, H * W, selfg, selfc // selfg, 1)
109 |
110 | dattk = dx @ uf.view(B, H * W, selfg, 1, selfc // selfg * KK)
111 | duf = attk.transpose(-2, -1) @ dx
112 | del attk, uf
113 |
114 | dattk = dattk.view(B, H * W, -1)
115 | datt += dattk @ selfw.transpose(-2, -1)
116 | dselfw = att.transpose(-2, -1) @ dattk
117 |
118 | duf = duf.permute(0, 2, 3, 4, 1).view(B, -1, H * W)
119 | dx = F.fold(duf, output_size=(H, W), kernel_size=selfk, padding=selfk // 2)
120 |
121 | datt = datt.transpose(-1, -2).view(B, nset, H, W)
122 |
123 | return dx, datt, None, None, dselfb, dselfw
124 |
--------------------------------------------------------------------------------
/python/trustmark/KBNet/kbnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Adobe
2 | # All Rights Reserved.
3 |
4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | # accordance with the terms of the Adobe license agreement accompanying
6 | # it.
7 |
8 | import math
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as thf
12 | try:
13 | import lightning as pl
14 | except ImportError:
15 | import pytorch_lightning as pl
16 | import einops
17 | import kornia
18 | import numpy as np
19 | import torchvision
20 | import importlib
21 | from torchmetrics.functional import peak_signal_noise_ratio
22 | from contextlib import contextmanager
23 | from omegaconf import OmegaConf
24 |
25 |
26 | class WMRemoverKBNet(pl.LightningModule):
27 | def __init__(self,
28 | cover_key,
29 | secret_key,
30 | secret_embedder_config_path,
31 | secret_embedder_ckpt_path,
32 | denoise_config,
33 | grad_clip,
34 | ckpt_path="__none__",
35 | ):
36 | super().__init__()
37 | self.automatic_optimization = False # for GAN training
38 | self.cover_key = cover_key
39 | self.secret_key = secret_key
40 | self.grad_clip = grad_clip
41 |
42 | secret_embedder_config = OmegaConf.load(secret_embedder_config_path).model
43 | secret_embedder_config.params.ckpt_path = secret_embedder_ckpt_path
44 | self.secret_len = secret_embedder_config.params.secret_len
45 |
46 | self.secret_embedder = instantiate_from_config(secret_embedder_config).eval()
47 | for p in self.secret_embedder.parameters():
48 | p.requires_grad = False
49 |
50 | self.denoise = instantiate_from_config(denoise_config)
51 | if ckpt_path != "__none__":
52 | self.init_from_ckpt(ckpt_path, ignore_keys=[])
53 |
54 | def init_from_ckpt(self, path, ignore_keys=[]):
55 | sd = torch.load(path, map_location="cpu")["state_dict"]
56 | keys = list(sd.keys())
57 | for k in keys:
58 | for ik in ignore_keys:
59 | if k.startswith(ik):
60 | print("Deleting key {} from state_dict.".format(k))
61 | del sd[k]
62 | self.load_state_dict(sd, strict=False)
63 | print(f"Restored from {path}")
64 |
65 | @torch.no_grad()
66 | def get_input(self, batch, bs=None):
67 | image = batch[self.cover_key]
68 | secret = batch[self.secret_key]
69 | if bs is not None:
70 | image = image[:bs]
71 | secret = secret[:bs]
72 | else:
73 | bs = image.shape[0]
74 | # encode image 1st stage
75 | image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
76 | stego = self.secret_embedder(image, secret)[0]
77 | out = [stego, image, secret]
78 | return out
79 |
80 | def forward(self, x):
81 | return torch.clamp(self.denoise(x), -1, 1)
82 |
83 | def shared_step(self, batch, batch_idx):
84 | is_training = self.training
85 | x, y, s = self.get_input(batch)
86 | if is_training:
87 | opt_g = self.optimizers()
88 | x_denoised = self(x)
89 | loss = torch.abs(x_denoised - y).mean()
90 | s_pred = self.secret_embedder.decoder(x_denoised)
91 | loss_dict = {}
92 | loss_dict['total_loss'] = loss
93 | loss_dict['bit_acc'] = ((torch.sigmoid(s_pred.detach()) > 0.5).float() == s).float().mean()
94 | if is_training:
95 | self.manual_backward(loss)
96 | if self.grad_clip:
97 | # torch.nn.utils.clip_grad_norm_(self.denoise.parameters(), 0.01)
98 | self.clip_gradients(opt_g, gradient_clip_val=0.01, gradient_clip_algorithm="norm")
99 | opt_g.step()
100 | opt_g.zero_grad()
101 |
102 | loss_dict['psnr_denoise'] = peak_signal_noise_ratio(x_denoised.detach(), y.detach(), data_range=2.0)
103 | loss_dict['psnr_stego'] = peak_signal_noise_ratio(x.detach(), y.detach(), data_range=2.0)
104 |
105 | return loss, loss_dict
106 |
107 | def training_step(self, batch, batch_idx):
108 | loss, loss_dict = self.shared_step(batch, batch_idx)
109 | # logging
110 | loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
111 | self.log_dict(loss_dict, prog_bar=True,
112 | logger=True, on_step=True, on_epoch=True)
113 |
114 | self.log("global_step", float(self.global_step),
115 | prog_bar=True, logger=True, on_step=True, on_epoch=False)
116 | # if self.use_scheduler:
117 | # lr = self.optimizers().param_groups[0]['lr']
118 | # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
119 | sch = self.lr_schedulers()
120 | self.log('lr_abs', sch.get_lr()[0], prog_bar=True, logger=True, on_step=True, on_epoch=False)
121 | sch.step()
122 |
123 | # return loss
124 |
125 | @torch.no_grad()
126 | def validation_step(self, batch, batch_idx):
127 | _, loss_dict_no_ema = self.shared_step(batch, batch_idx)
128 | loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
129 | self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
130 |
131 | @torch.no_grad()
132 | def log_images(self, batch, fixed_input=False, **kwargs):
133 | log = dict()
134 | x, y, s = self.get_input(batch)
135 | x_denoised = self(x)
136 | log['clean'] = y
137 | log['stego'] = x
138 | log['denoised'] = x_denoised
139 | residual = x_denoised - y
140 | log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1
141 | return log
142 |
143 | def configure_optimizers(self):
144 | lr = self.learning_rate
145 | params_g = list(self.denoise.parameters())
146 | opt_g = torch.optim.AdamW(params_g, lr=lr, weight_decay=1e-4, betas=(0.9, 0.999))
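 | # NOTE: lr_scheduler is not imported in this file; CosineAnnealingRestartCyclicLR
 | # comes from the KBNet training utilities (a basicsr fork). configure_optimizers
 | # is only exercised during training, not during watermark-removal inference.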
147 | lr_sch = lr_scheduler.CosineAnnealingRestartCyclicLR(
148 | opt_g, periods=[92000, 208000], restart_weights= [1,1], eta_mins=[0.0003,0.000001])
149 | return [opt_g], [lr_sch]
150 |
151 |
152 |
153 | def get_obj_from_str(string, reload=False):
154 | module, cls = string.rsplit(".", 1)
155 | if reload:
156 | module_imp = importlib.import_module(module)
157 | importlib.reload(module_imp)
158 | return getattr(importlib.import_module(module, package=None), cls)
159 |
160 |
161 |
162 | def instantiate_from_config(config):
163 | if not "target" in config:
164 | if config == '__is_first_stage__':
165 | return None
166 | elif config == "__is_unconditional__":
167 | return None
168 | raise KeyError("Expected key `target` to instantiate.")
169 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
170 |
171 |
172 |
--------------------------------------------------------------------------------
/python/trustmark/KBNet/kbnet_l_arch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe
2 | # All Rights Reserved.
3 |
4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | # accordance with the terms of the Adobe license agreement accompanying
6 | # it.
7 |
8 | import math
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.nn.init as init
14 |
15 | from einops import rearrange
16 |
17 | from .kb_utils import KBAFunction
18 | from .kb_utils import LayerNorm2d, SimpleGate
19 |
20 |
21 | class Downsample(nn.Module):
22 | def __init__(self, n_feat):
23 | super(Downsample, self).__init__()
24 |
25 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
26 | nn.PixelUnshuffle(2))
27 |
28 | def forward(self, x):
29 | return self.body(x)
30 |
31 |
32 | class Upsample(nn.Module):
33 | def __init__(self, n_feat):
34 | super(Upsample, self).__init__()
35 |
36 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
37 | nn.PixelShuffle(2))
38 |
39 | def forward(self, x):
40 | return self.body(x)
41 |
42 |
43 | class OverlapPatchEmbed(nn.Module):
44 | def __init__(self, in_c=3, embed_dim=48, bias=False):
45 | super(OverlapPatchEmbed, self).__init__()
46 |
47 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
48 |
49 | def forward(self, x):
50 | x = self.proj(x)
51 | return x
52 |
53 |
54 | class TransAttention(nn.Module):
55 | def __init__(self, dim, num_heads, bias):
56 | super(TransAttention, self).__init__()
57 | self.num_heads = num_heads
58 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
59 |
60 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
61 | self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
62 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
63 |
64 | def forward(self, x):
65 | b, c, h, w = x.shape
66 |
67 | qkv = self.qkv_dwconv(self.qkv(x))
68 | q, k, v = qkv.chunk(3, dim=1)
69 |
70 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
71 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
72 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
73 |
74 | q = torch.nn.functional.normalize(q, dim=-1)
75 | k = torch.nn.functional.normalize(k, dim=-1)
76 |
77 | attn = (q @ k.transpose(-2, -1)) * self.temperature
78 | attn = attn.softmax(dim=-1)
79 |
80 | out = (attn @ v)
81 |
82 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
83 |
84 | out = self.project_out(out)
85 | return out
86 |
87 |
88 | class MFF(nn.Module):
89 | def __init__(self, dim, ffn_expansion_factor, bias, act=True, gc=2, nset=32, k=3):
90 | super(MFF, self).__init__()
91 | self.act = act
92 | self.gc = gc
93 |
94 | hidden_features = int(dim * ffn_expansion_factor)
95 |
96 | self.dwconv = nn.Sequential(
97 | nn.Conv2d(dim, hidden_features, kernel_size=1, bias=bias),
98 | nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1,
99 | groups=hidden_features, bias=bias),
100 | )
101 |
102 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
103 |
104 | self.sca = nn.Sequential(
105 | nn.AdaptiveAvgPool2d(1),
106 | nn.Conv2d(in_channels=dim, out_channels=hidden_features, kernel_size=1, padding=0, stride=1,
107 | groups=1, bias=True),
108 | )
109 | self.conv1 = nn.Sequential(
110 | nn.Conv2d(dim, hidden_features, kernel_size=1, bias=bias),
111 | nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1,
112 | groups=hidden_features, bias=bias),
113 | )
114 |
115 | c = hidden_features
116 | self.k, self.c = k, c
117 | self.nset = nset
118 |
119 | self.g = c // gc
120 | self.w = nn.Parameter(torch.zeros(1, nset, c * c // self.g * self.k ** 2))
121 | self.b = nn.Parameter(torch.zeros(1, nset, c))
122 | self.init_p(self.w, self.b)
123 | interc = min(dim, 24)
124 | # print(c, interc)
125 | self.conv2 = nn.Sequential(
126 | nn.Conv2d(in_channels=dim, out_channels=interc, kernel_size=3, padding=1, stride=1, groups=interc,
127 | bias=True),
128 | SimpleGate(),
129 | nn.Conv2d(interc // 2, self.nset, 1, padding=0, stride=1),
130 | )
131 | self.conv211 = nn.Conv2d(in_channels=dim, out_channels=self.nset, kernel_size=1)
132 | self.attgamma = nn.Parameter(torch.zeros((1, self.nset, 1, 1)) + 1e-2, requires_grad=True)
133 | self.ga1 = nn.Parameter(torch.zeros((1, hidden_features, 1, 1)) + 1e-2, requires_grad=True)
134 |
135 | def forward(self, x):
136 | sca = self.sca(x)
137 | x1 = self.dwconv(x)
138 |
139 | att = self.conv2(x) * self.attgamma + self.conv211(x)
140 | uf = self.conv1(x)
141 | x2 = self.KBA(uf, att, self.k, self.g, self.b, self.w) * self.ga1 + uf
142 |
143 | x = F.gelu(x1) * x2 if self.act else x1 * x2
144 | x = x * sca
145 | x = self.project_out(x)
146 | return x
147 |
148 | def init_p(self, weight, bias=None):
149 | init.kaiming_uniform_(weight, a=math.sqrt(5))
150 | if bias is not None:
151 | fan_in, _ = init._calculate_fan_in_and_fan_out(weight)
152 | bound = 1 / math.sqrt(fan_in)
153 | init.uniform_(bias, -bound, bound)
154 |
155 | def KBA(self, x, att, selfk, selfg, selfb, selfw):
156 | return KBAFunction.apply(x, att, selfk, selfg, selfb, selfw)
157 |
158 |
159 | class KBBlock_l(nn.Module):
160 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias):
161 | super(KBBlock_l, self).__init__()
162 |
163 | self.norm1 = LayerNorm2d(dim)
164 | self.norm2 = LayerNorm2d(dim)
165 |
166 | self.attn = MFF(dim, ffn_expansion_factor, bias)
167 | self.ffn = TransAttention(dim, num_heads, bias)
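 | # Note: the attribute names follow the original KBNet release and are swapped
 | # relative to their contents (`attn` holds the MFF feed-forward block and `ffn`
 | # the attention); renaming them would invalidate released checkpoint keys.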
168 |
169 | def forward(self, x):
170 | x = x + self.attn(self.norm1(x))
171 | x = x + self.ffn(self.norm2(x))
172 | return x
173 |
174 |
175 | class KBNet_l(nn.Module):
176 | def __init__(self, inp_channels=3, out_channels=3, dim=48, num_blocks=[4, 6, 6, 8], num_refinement_blocks=4,
177 | heads=[1, 2, 4, 8], ffn_expansion_factor=1.5, bias=False,
178 | blockname='KBBlock_l'):
179 | super(KBNet_l, self).__init__()
180 | # print('\r** ', blockname, end='')
181 | TransformerBlock = eval(blockname)
182 |
183 | self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
184 |
185 | self.encoder_level1 = nn.Sequential(*[
186 | TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias) for i in
187 | range(num_blocks[0])])
188 |
189 | self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
190 | self.encoder_level2 = nn.Sequential(*[
191 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
192 | bias=bias) for i in range(num_blocks[1])])
193 |
194 | self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3
195 | self.encoder_level3 = nn.Sequential(*[
196 | TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
197 | bias=bias) for i in range(num_blocks[2])])
198 |
199 | self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4
200 | self.latent = nn.Sequential(*[
201 | TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor,
202 | bias=bias) for i in range(num_blocks[3])])
203 |
204 | self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3
205 | self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
206 | self.decoder_level3 = nn.Sequential(*[
207 | TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
208 | bias=bias) for i in range(num_blocks[2])])
209 |
210 | self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2
211 | self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
212 | self.decoder_level2 = nn.Sequential(*[
213 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
214 | bias=bias) for i in range(num_blocks[1])])
215 |
216 | self.up2_1 = Upsample(int(dim * 2 ** 1))
217 |
218 | self.decoder_level1 = nn.Sequential(*[
219 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
220 | bias=bias) for i in range(num_blocks[0])])
221 |
222 | self.refinement = nn.Sequential(*[
223 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
224 | bias=bias) for i in range(num_refinement_blocks)])
225 |
226 | self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
227 |
228 | def forward(self, inp_img):
229 | inp_enc_level1 = self.patch_embed(inp_img)
230 | out_enc_level1 = self.encoder_level1(inp_enc_level1)
231 |
232 | inp_enc_level2 = self.down1_2(out_enc_level1)
233 | out_enc_level2 = self.encoder_level2(inp_enc_level2)
234 |
235 | inp_enc_level3 = self.down2_3(out_enc_level2)
236 | out_enc_level3 = self.encoder_level3(inp_enc_level3)
237 |
238 | inp_enc_level4 = self.down3_4(out_enc_level3)
239 | latent = self.latent(inp_enc_level4)
240 |
241 | inp_dec_level3 = self.up4_3(latent)
242 | inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
243 | inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
244 | out_dec_level3 = self.decoder_level3(inp_dec_level3)
245 |
246 | inp_dec_level2 = self.up3_2(out_dec_level3)
247 | inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
248 | inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
249 | out_dec_level2 = self.decoder_level2(inp_dec_level2)
250 |
251 | inp_dec_level1 = self.up2_1(out_dec_level2)
252 | inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
253 | out_dec_level1 = self.decoder_level1(inp_dec_level1)
254 |
255 | out_dec_level1 = self.refinement(out_dec_level1)
256 |
257 | out_dec_level1 = self.output(out_dec_level1) + inp_img
258 |
259 | return out_dec_level1
260 |
--------------------------------------------------------------------------------
/python/trustmark/KBNet/kbnet_s_arch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe
2 | # All Rights Reserved.
3 |
4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | # accordance with the terms of the Adobe license agreement accompanying
6 | # it.
7 |
8 | import math
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.nn.init as init
14 |
15 | from .kb_utils import KBAFunction
16 | from .kb_utils import LayerNorm2d, SimpleGate
17 |
18 |
19 | class KBBlock_s(nn.Module):
20 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, nset=32, k=3, gc=4, lightweight=False):
21 | super(KBBlock_s, self).__init__()
22 | self.k, self.c = k, c
23 | self.nset = nset
24 | dw_ch = int(c * DW_Expand)
25 | ffn_ch = int(FFN_Expand * c)
26 |
27 | self.g = c // gc
28 | self.w = nn.Parameter(torch.zeros(1, nset, c * c // self.g * self.k ** 2))
29 | self.b = nn.Parameter(torch.zeros(1, nset, c))
30 | self.init_p(self.w, self.b)
31 |
32 | self.norm1 = LayerNorm2d(c)
33 | self.norm2 = LayerNorm2d(c)
34 |
35 | self.sca = nn.Sequential(
36 | nn.AdaptiveAvgPool2d(1),
37 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=1, padding=0, stride=1,
38 | groups=1, bias=True),
39 | )
40 |
41 | if not lightweight:
42 | self.conv11 = nn.Sequential(
43 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1,
44 | bias=True),
45 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=5, padding=2, stride=1, groups=c // 4,
46 | bias=True),
47 | )
48 | else:
49 | self.conv11 = nn.Sequential(
50 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1,
51 | bias=True),
52 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=3, padding=1, stride=1, groups=c,
53 | bias=True),
54 | )
55 |
56 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1,
57 | bias=True)
58 | self.conv21 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=3, padding=1, stride=1, groups=c,
59 | bias=True)
60 |
61 | interc = min(c, 32)
62 | self.conv2 = nn.Sequential(
63 | nn.Conv2d(in_channels=c, out_channels=interc, kernel_size=3, padding=1, stride=1, groups=interc,
64 | bias=True),
65 | SimpleGate(),
66 | nn.Conv2d(interc // 2, self.nset, 1, padding=0, stride=1),
67 | )
68 |
69 | self.conv211 = nn.Conv2d(in_channels=c, out_channels=self.nset, kernel_size=1)
70 |
71 | self.conv3 = nn.Conv2d(in_channels=dw_ch // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
72 | groups=1, bias=True)
73 |
74 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_ch, kernel_size=1, padding=0, stride=1, groups=1,
75 | bias=True)
76 | self.conv5 = nn.Conv2d(in_channels=ffn_ch // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
77 | groups=1, bias=True)
78 |
79 | self.dropout1 = nn.Identity()
80 | self.dropout2 = nn.Identity()
81 |
82 | self.ga1 = nn.Parameter(torch.zeros((1, c, 1, 1)) + 1e-2, requires_grad=True)
83 | self.attgamma = nn.Parameter(torch.zeros((1, self.nset, 1, 1)) + 1e-2, requires_grad=True)
84 | self.sg = SimpleGate()
85 |
86 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)) + 1e-2, requires_grad=True)
87 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)) + 1e-2, requires_grad=True)
88 |
89 | def init_p(self, weight, bias=None):
90 | init.kaiming_uniform_(weight, a=math.sqrt(5))
91 | if bias is not None:
92 | fan_in, _ = init._calculate_fan_in_and_fan_out(weight)
93 | bound = 1 / math.sqrt(fan_in)
94 | init.uniform_(bias, -bound, bound)
95 |
96 | def KBA(self, x, att, selfk, selfg, selfb, selfw):
97 | return KBAFunction.apply(x, att, selfk, selfg, selfb, selfw)
98 |
99 | def forward(self, inp):
100 | x = inp
101 |
102 | x = self.norm1(x)
103 | sca = self.sca(x)
104 | x1 = self.conv11(x)
105 |
106 | # KBA module
107 | att = self.conv2(x) * self.attgamma + self.conv211(x)
108 | uf = self.conv21(self.conv1(x))
109 | x = self.KBA(uf, att, self.k, self.g, self.b, self.w) * self.ga1 + uf
110 | x = x * x1 * sca
111 |
112 | x = self.conv3(x)
113 | x = self.dropout1(x)
114 | y = inp + x * self.beta
115 |
116 | # FFN
117 | x = self.norm2(y)
118 | x = self.conv4(x)
119 | x = self.sg(x)
120 | x = self.conv5(x)
121 |
122 | x = self.dropout2(x)
123 | return y + x * self.gamma
124 |
125 |
126 | class KBNet_s(nn.Module):
127 | def __init__(self, img_channel=3, width=64, middle_blk_num=12, enc_blk_nums=[2, 2, 4, 8],
128 | dec_blk_nums=[2, 2, 2, 2], basicblock='KBBlock_s', lightweight=False, ffn_scale=2):
129 | super().__init__()
130 | basicblock = eval(basicblock)
131 |
132 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1,
133 | groups=1, bias=True)
134 |
135 | self.encoders = nn.ModuleList()
136 | self.middle_blks = nn.ModuleList()
137 | self.decoders = nn.ModuleList()
138 |
139 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1,
140 | groups=1, bias=True)
141 |
142 | self.ups = nn.ModuleList()
143 | self.downs = nn.ModuleList()
144 |
145 | chan = width
146 | for num in enc_blk_nums:
147 | self.encoders.append(
148 | nn.Sequential(
149 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(num)]
150 | )
151 | )
152 | self.downs.append(
153 | nn.Conv2d(chan, 2 * chan, 2, 2)
154 | )
155 | chan = chan * 2
156 |
157 | self.middle_blks = \
158 | nn.Sequential(
159 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(middle_blk_num)]
160 | )
161 |
162 | for num in dec_blk_nums:
163 | self.ups.append(
164 | nn.Sequential(
165 | nn.Conv2d(chan, chan * 2, 1, bias=False),
166 | nn.PixelShuffle(2)
167 | )
168 | )
169 | chan = chan // 2
170 | self.decoders.append(
171 | nn.Sequential(
172 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(num)]
173 | )
174 | )
175 |
176 | self.padder_size = 2 ** len(self.encoders)
177 |
178 | def forward(self, inp):
179 | B, C, H, W = inp.shape
180 | inp = self.check_image_size(inp)
181 | x = self.intro(inp)
182 |
183 | encs = []
184 |
185 | for encoder, down in zip(self.encoders, self.downs):
186 | x = encoder(x)
187 | encs.append(x)
188 | x = down(x)
189 |
190 | x = self.middle_blks(x)
191 |
192 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
193 | x = up(x)
194 | x = x + enc_skip
195 | x = decoder(x)
196 |
197 | x = self.ending(x)
198 | x = x + inp
199 |
200 | return x[:, :, :H, :W]
201 |
202 | def check_image_size(self, x):
203 | _, _, h, w = x.size()
204 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
205 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
206 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
207 | return x
208 |
209 |
210 | class SecretEncoderKBNet(nn.Module):
211 | def __init__(self, resolution=256, img_channel=3, secret_len=100, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 2, 4],
212 | dec_blk_nums=[1, 1, 1, 1], basicblock='KBBlock_s', lightweight=True, ffn_scale=1.5):
213 | super().__init__()
214 | basicblock = eval(basicblock)
215 | assert resolution % 16 == 0, 'resolution must be divisible by 16'
216 | self.secret_len = secret_len
217 |
218 | self.intro = nn.Conv2d(in_channels=img_channel*2, out_channels=width, kernel_size=3, padding=1, stride=1,
219 | groups=1, bias=True)
220 | self.secret_pre = nn.Linear(secret_len, 16*16*img_channel)
221 | self.secret_up = nn.Upsample(scale_factor=(resolution//16, resolution//16))
222 |
223 | self.encoders = nn.ModuleList()
224 | self.middle_blks = nn.ModuleList()
225 | self.decoders = nn.ModuleList()
226 |
227 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1,
228 | groups=1, bias=True)
229 |
230 | self.ups = nn.ModuleList()
231 | self.downs = nn.ModuleList()
232 | self.skip = nn.ModuleList()
233 |
234 | chan = width
235 | for num in enc_blk_nums:
236 | self.encoders.append(
237 | nn.Sequential(
238 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(num)]
239 | )
240 | )
241 | self.downs.append(
242 | nn.Conv2d(chan, 2 * chan, 2, 2)
243 | )
244 | chan = chan * 2
245 |
246 | self.middle_blks = \
247 | nn.Sequential(
248 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(middle_blk_num)]
249 | )
250 |
251 | for num in dec_blk_nums:
252 | self.ups.append(
253 | nn.Sequential(
254 | # nn.Conv2d(chan, chan * 2, 1, bias=False),
255 | # nn.PixelShuffle(2)
256 | nn.Upsample(scale_factor=(2,2)),
257 | nn.ZeroPad2d((0, 1, 0, 1)),
258 | nn.Conv2d(chan, chan // 2, 2, 1),
259 |
260 | )
261 | )
262 | self.skip.append(
263 | nn.Sequential(
264 | nn.Conv2d(chan, chan//2, 3, 1, 1),
265 | nn.ReLU(inplace=True),
266 | )
267 | )
268 | chan = chan // 2
269 | self.decoders.append(
270 | nn.Sequential(
271 | *[basicblock(chan, FFN_Expand=ffn_scale, lightweight=lightweight) for _ in range(num)]
272 | )
273 | )
274 |
275 | self.padder_size = 2 ** len(self.encoders)
276 |
277 | def forward(self, image, secret):
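 | # Project the secret bits to a 16x16 spatial map, upsample it to the image
 | # resolution, then concatenate with the image channel-wise before encoding.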
278 | H, W = image.shape[-2:]
279 | image = self.check_image_size(image)
280 | secret = F.relu(self.secret_pre(secret))
281 | secret = secret.view(-1, image.shape[1], 16, 16)
282 | secret = self.secret_up(secret)
283 | imgsec = torch.cat([image, secret], dim=1) # B, 6, 256, 256
284 | x = self.intro(imgsec)
285 |
286 | encs = []
287 |
288 | for encoder, down in zip(self.encoders, self.downs):
289 | x = encoder(x)
290 | encs.append(x)
291 | x = down(x)
292 |
293 | x = self.middle_blks(x)
294 | # import pdb; pdb.set_trace()
295 | for decoder, up, enc_skip, skip in zip(self.decoders, self.ups, encs[::-1], self.skip):
296 | x = up(x)
297 | # x = x + enc_skip
298 | x = torch.cat([enc_skip, x], dim=1)
299 | x = skip(x)
300 | x = decoder(x)
301 |
302 | x = self.ending(x) # B, 3, 256, 256
303 | # x = x + image
304 | x = torch.tanh(x)
305 |
306 | return x[:, :, :H, :W]
307 |
308 | def check_image_size(self, x):
309 | _, _, h, w = x.size()
310 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
311 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
312 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
313 | return x
314 |
315 |
--------------------------------------------------------------------------------
/python/trustmark/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 |
7 | from .trustmark import TrustMark
8 |
--------------------------------------------------------------------------------
/python/trustmark/bchecc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe
2 | # All Rights Reserved.
3 |
4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | # accordance with the terms of the Adobe license agreement accompanying
6 | # it.
7 |
8 |
9 | from dataclasses import dataclass
10 | from copy import deepcopy
11 |
12 |
13 | class BCH(object):
14 |
15 | @dataclass
16 | class params:
17 | m: int
18 | t: int
19 | poly: int
20 |
21 | @dataclass
22 | class polynomial:
23 | deg: int
24 |
25 |
26 |
27 | ### GALOIS OPERATIONS
28 |
29 | def g_inv(self,a):
30 | return self.ECCstate.exponents[self.ECCstate.n - self.ECCstate.logarithms[a]]
31 |
32 | def g_sqrt(self, a):
33 | if a:
34 | return self.ECCstate.exponents[self.mod(2*self.ECCstate.logarithms[a])]
35 | else:
36 | return 0
37 |
38 | def mod(self, v):
39 | if v < self.ECCstate.n:
40 | return v
41 | else:
42 | return v - self.ECCstate.n
43 | 
44 | def g_mul(self, a, b):
45 | # Galois-field multiply via the log/exponent tables (lines 39-46 reconstructed)
46 | if (a>0 and b>0):
47 | res=self.mod(self.ECCstate.logarithms[a]+self.ECCstate.logarithms[b])
48 | return (self.ECCstate.exponents[res])
49 | else:
50 | return 0
51 |
52 | def g_div(self,a,b):
53 | if a:
54 | return self.ECCstate.exponents[self.mod(self.ECCstate.logarithms[a]+self.ECCstate.n-self.ECCstate.logarithms[b])]
55 | else:
56 | return 0
57 |
58 | def modn(self, v):
59 | n=self.ECCstate.n
60 | while (v>=n):
61 | v -= n
62 | v = (v & n) + (v >> self.ECCstate.m)
63 | return v
64 |
65 | def g_log(self, x):
66 | return self.ECCstate.logarithms[x]
67 |
68 | def a_ilog(self, x):
69 | return self.mod(self.ECCstate.n- self.ECCstate.logarithms[x])
70 |
71 | def g_pow(self, i):
72 | return self.ECCstate.exponents[self.modn(i)]
73 |
74 | def deg(self, x):
75 | count=0
76 | while (x >> 1):
77 | x = x >> 1
78 | count += 1
79 | return count
80 |
81 |
82 | def ceilop(self, a, b):
83 | return int((a + b - 1) / b)
84 |
85 | def load4bytes(self, data):
86 | w=0
87 | w += data[0] << 24
88 | w += data[1] << 16
89 | w += data[2] << 8
90 | w += data[3] << 0
91 | return w
92 |
93 |
94 | def getroots(self, k, poly):
95 |
96 | roots=[]
97 |
98 | if poly.deg>2:
99 | k=k*8+self.ECCstate.ecc_bits
100 |
101 | rep=[0]*(self.ECCstate.t*2)
102 | d=poly.deg
103 | l=self.ECCstate.n-self.g_log(poly.c[poly.deg])
104 | for i in range(0,d):
105 | if poly.c[i]:
106 | rep[i]=self.mod(self.g_log(poly.c[i])+l)
107 | else:
108 | rep[i]=-1
109 |
110 | rep[poly.deg]=0
111 | syn0=self.g_div(poly.c[0],poly.c[poly.deg])
112 | for i in range(self.ECCstate.n-k+1, self.ECCstate.n+1):
113 | syn=syn0
114 | for j in range(1,poly.deg+1):
115 | m=rep[j]
116 | if m>=0:
117 | syn = syn ^ self.g_pow(m+j*i)
118 | if syn==0:
119 | roots.append(self.ECCstate.n-i)
120 | if len(roots)==poly.deg:
121 | break
122 | if len(roots)<poly.deg:
123 | # NOTE: lines 122-170 were lost to extraction damage; reconstructed
124 | # here following the reference BCH codec that this file ports
125 | roots=[]
126 | 
127 | elif poly.deg==1:
128 | # linear case: the single root is c[0]/c[1]
129 | if poly.c[0]:
130 | roots.append(self.mod(self.ECCstate.n-self.g_log(poly.c[0])+self.g_log(poly.c[1])))
131 | 
132 | elif poly.deg==2:
133 | # quadratic case: solve x^2+bx+c using the precomputed elp_pre table
134 | if poly.c[0] and poly.c[1]:
135 | l0=self.g_log(poly.c[0])
136 | l1=self.g_log(poly.c[1])
137 | l2=self.g_log(poly.c[2])
138 | u=self.g_pow(l0+l2+2*(self.ECCstate.n-l1))
139 | r=0
140 | v=u
141 | while (v):
142 | i=self.deg(v)
143 | r = r ^ self.ECCstate.elp_pre[i]
144 | v = v ^ pow(2,i)
145 | if self.g_sqrt(r)^r == u:
146 | roots.append(self.modn(2*self.ECCstate.n-l1+self.g_log(r)+l2))
147 | roots.append(self.modn(2*self.ECCstate.n-l1+self.g_log(r^1)+l2))
148 | 
149 | return roots
150 | 
151 | 
152 | ### ENCODE/DECODE OPERATIONS
153 | 
154 | def decode(self, data, recvecc):
155 | # correct up to t bitflips in (data, recvecc); returns -1 when the
156 | # errors are uncorrectable, otherwise the number of bitflips found
157 | calc_ecc=self.encode(data)  # ecc of the received data (kept in ecc_buf)
158 | data=bytearray(data)
159 | recvecc=bytearray(recvecc)
160 | 
161 | self.ECCstate.errloc=[]
162 | 
163 | # split the received ecc into whole 32-bit words; any remaining
164 | # bytes are zero-padded to a full word below
165 | mlen=int(len(recvecc)/4)
166 | eccbuf=[]
167 | offset=0
168 | 
169 | 
170 | while (mlen>0):
171 | w=self.load4bytes(recvecc[offset:(offset+4)])
172 | eccbuf.append(w)
173 | offset+=4
174 | mlen -=1
175 | recvecc=recvecc[offset:]
176 | leftdata=len(recvecc)
177 | if leftdata>0: #pad it to 4
178 | recvecc=recvecc+bytes([0]*(4-leftdata))
179 | w=self.load4bytes(recvecc)
180 | eccbuf.append(w)
181 |
182 | eccwords=self.ceilop(self.ECCstate.m*self.ECCstate.t, 32)
183 |
184 | sum=0
185 | for i in range(0,eccwords):
186 | self.ECCstate.ecc_buf[i] = self.ECCstate.ecc_buf[i] ^ eccbuf[i]
187 | sum = sum | self.ECCstate.ecc_buf[i]
188 | if sum==0:
189 | return 0 # no bit flips
190 |
191 |
192 | s=self.ECCstate.ecc_bits
193 | t=self.ECCstate.t
194 | syn=[0]*(2*t)
195 |
196 | m= s & 31
197 |
198 | synbuf=self.ECCstate.ecc_buf
199 |
200 | if (m):
201 | synbuf[int(s/32)] = synbuf[int(s/32)] & ~(pow(2,32-m)-1)
202 |
203 | synptr=0
204 | while(s>0 or synptr==0):
205 | poly=synbuf[synptr]
206 | synptr += 1
207 | s-= 32
208 | while (poly):
209 | i=self.deg(poly)
210 | for j in range(0,(2*t),2):
211 | syn[j]=syn[j] ^ self.g_pow((j+1)*(i+s))
212 | poly = poly ^ pow(2,i)
213 |
214 |
215 | for i in range(0,t):
216 | syn[2*i+1]=self.g_sqrt(syn[i])
217 |
218 |
219 | n=self.ECCstate.n
220 | t=self.ECCstate.t
221 | pp=-1
222 | pd=1
223 |
224 | pelp=self.polynomial(deg=0)
225 | pelp.deg=0
226 | pelp.c= [0]*(2*t)
227 | pelp.c[0]=1
228 |
229 | elp=self.polynomial(deg=0)
230 | elp.c= [0]*(2*t)
231 | elp.c[0]=1
232 |
233 | d=syn[0]
234 |
235 | elp_copy=self.polynomial(deg=0)
236 | for i in range(0,t):
237 | if (elp.deg>t):
238 | break
239 | if d:
240 | k=2*i-pp
241 | elp_copy=deepcopy(elp)
242 | tmp=self.g_log(d)+n-self.g_log(pd)
243 | for j in range(0,(pelp.deg+1)):
244 | if (pelp.c[j]):
245 | l=self.g_log(pelp.c[j])
246 | elp.c[j+k]=elp.c[j+k] ^ self.g_pow(tmp+l)
247 |
248 |
249 | tmp=pelp.deg+k
250 | if tmp>elp.deg:
251 | elp.deg=tmp
252 | pelp=deepcopy(elp_copy)
253 | pd=d
254 | pp=2*i
255 | if (i<t-1):
256 | # compute the discrepancy for the next iteration (reconstructed span)
257 | d=syn[2*i+2]
258 | for j in range(1,elp.deg+1):
259 | d = d ^ self.g_mul(elp.c[j],syn[2*i+2-j])
260 | 
261 | self.ECCstate.elp=elp
262 | 
263 | # locate roots; map each to a bit position in the message
264 | nbits=(len(data)*8)+self.ECCstate.ecc_bits
265 | self.ECCstate.errloc=self.getroots(len(data),self.ECCstate.elp)
266 | for i in range(0,len(self.ECCstate.errloc)):
267 | if self.ECCstate.errloc[i] >= nbits:
268 | return -1
269 | self.ECCstate.errloc[i]=nbits-1-self.ECCstate.errloc[i]
270 | self.ECCstate.errloc[i]=(self.ECCstate.errloc[i] & ~7) | (7-(self.ECCstate.errloc[i] & 7))
271 |
272 |
273 | for bitflip in self.ECCstate.errloc:
274 | byte= int (bitflip / 8)
275 | bit = pow(2,(bitflip & 7))
276 | if bitflip < (len(data)+len(recvecc))*8:
277 | if byte<len(data):
278 | data[byte]=data[byte] ^ bit
279 | else:
280 | recvecc[byte-len(data)]=recvecc[byte-len(data)] ^ bit
281 | 
282 | return len(self.ECCstate.errloc)
283 | 
284 | 
285 | def encode(self, data):
286 | # compute the ecc of data using the precomputed cyclic tables,
287 | # consuming the input one 32-bit word at a time
288 | # (lines 277-303 reconstructed after extraction damage)
289 | datalen=len(data)
290 | # l+1 32-bit words hold the running ecc/remainder state
291 | l=self.ceilop(self.ECCstate.m*self.ECCstate.t,32)-1
292 | r=[0]*(l+1)
293 | 
294 | # offsets of the four 256-entry lookup subtables
295 | tab0idx=0
296 | tab1idx=tab0idx+256*(l+1)
297 | tab2idx=tab1idx+256*(l+1)
298 | tab3idx=tab2idx+256*(l+1)
299 | 
300 | mlen=int(datalen/4)
301 | offset=0
302 | 
303 | while (mlen>0):
304 | w=self.load4bytes(data[offset:(offset+4)])
305 | w=w^r[0]
306 | p0=tab0idx+(l+1)*((w>>0) & 0xff)
307 | p1=tab1idx+(l+1)*((w>>8) & 0xff)
308 | p2=tab2idx+(l+1)*((w>>16) & 0xff)
309 | p3=tab3idx+(l+1)*((w>>24) & 0xff)
310 |
311 | for i in range(0,l):
312 | r[i]=r[i+1] ^ self.ECCstate.cyclic_tab[p0+i] ^ self.ECCstate.cyclic_tab[p1+i] ^ self.ECCstate.cyclic_tab[p2+i] ^ self.ECCstate.cyclic_tab[p3+i]
313 |
314 | r[l] = self.ECCstate.cyclic_tab[p0+l]^self.ECCstate.cyclic_tab[p1+l]^self.ECCstate.cyclic_tab[p2+l]^self.ECCstate.cyclic_tab[p3+l];
315 | mlen -=1
316 | offset +=4
317 |
318 |
319 | data=data[offset:]
320 | leftdata=len(data)
321 |
322 | ecc=r
323 | posn=0
324 | while (leftdata):
325 | tmp=data[posn]
326 | posn += 1
327 | pidx = (l+1)*(((ecc[0] >> 24)^(tmp)) & 0xff)
328 | for i in range(0,l):
329 | ecc[i]=(((ecc[i] << 8)&0xffffffff)|ecc[i+1]>>24)^(self.ECCstate.cyclic_tab[pidx])
330 | pidx += 1
331 | ecc[l]=((ecc[l] << 8)&0xffffffff)^(self.ECCstate.cyclic_tab[pidx])
332 | leftdata -= 1
333 |
334 | self.ECCstate.ecc_buf=ecc
335 | eccout=[]
336 | for e in r:
337 | eccout.append((e >> 24) & 0xff)
338 | eccout.append((e >> 16) & 0xff)
339 | eccout.append((e >> 8) & 0xff)
340 | eccout.append((e >> 0) & 0xff)
341 |
342 | eccout=eccout[0:self.ECCstate.ecc_bytes]
343 |
344 | eccbytes=(bytearray(bytes(eccout)))
345 | return eccbytes
346 |
347 |
348 |
349 | def build_cyclic(self, g):
350 |
351 | l=self.ceilop(self.ECCstate.m*self.ECCstate.t, 32)
352 |
353 | plen=self.ceilop(self.ECCstate.ecc_bits+1,32)
354 | ecclen=self.ceilop(self.ECCstate.ecc_bits,32)
355 |
356 | self.ECCstate.cyclic_tab = [0] * 4*256*l
357 |
358 | for i in range(0,256):
359 | for b in range(0,4):
360 | offset= (b*256+i)*l
361 | data = i << 8*b
362 | while (data):
363 |
364 | d=self.deg(data)
365 | data = data ^ (g[0] >> (31-d))
366 | for j in range(0,ecclen):
367 | if d<31:
368 | hi=(g[j] << (d+1)) & 0xffffffff
369 | else:
370 | hi=0
371 | if j+1 < plen:
372 | lo= g[j+1] >> (31-d)
373 | else:
374 | lo= 0
375 | self.ECCstate.cyclic_tab[j+offset] = self.ECCstate.cyclic_tab[j+offset] ^ (hi | lo)
376 |
377 |
378 | def __init__(self, t, poly):
379 |
380 | tmp = poly;
381 | m = 0;
382 | while (tmp >> 1):
383 | tmp =tmp >> 1
384 | m +=1
385 |
386 | self.ECCstate=self.params(m=m,t=t,poly=poly)
387 |
388 | self.ECCstate.n=pow(2,m)-1
389 | words = self.ceilop(m*t,32)
390 | self.ECCstate.ecc_bytes = self.ceilop(m*t,8)
391 | self.ECCstate.cyclic_tab=[0]*(words*1024)
392 | self.ECCstate.syn=[0]*(2*t)
393 | self.ECCstate.elp=[0]*(t+1)
394 | self.ECCstate.errloc=[0] * t
395 |
396 |
397 | x=1
398 | k=pow(2,self.deg(poly))
399 | if k != pow(2,self.ECCstate.m):
400 | return -1
401 |
402 | self.ECCstate.exponents=[0]*(1+self.ECCstate.n)
403 | self.ECCstate.logarithms=[0]*(1+self.ECCstate.n)
404 | self.ECCstate.elp_pre=[0]*(1+self.ECCstate.m)
405 |
406 | for i in range(0,self.ECCstate.n):
407 | self.ECCstate.exponents[i]=x
408 | self.ECCstate.logarithms[x]=i
409 | if i and x==1:
410 | return -1
411 | x*= 2
412 | if (x & k):
413 | x=x^poly
414 |
415 | self.ECCstate.logarithms[0]=0
416 | self.ECCstate.exponents[self.ECCstate.n]=1
417 |
418 |
419 |
420 | n=0
421 | g=self.polynomial(deg=0)
422 | g.c=[0]*((m*t)+1)
423 | roots=[0]*(self.ECCstate.n+1)
424 | genpoly=[0]*self.ceilop(m*t+1,32)
425 |
426 | # enum all roots
427 | for i in range(0,t):
428 | r=2*i+1
429 | for j in range(0,m):
430 | roots[r]=1
431 | r=self.mod(2*r)
432 |
433 | # build g(x)
434 | g.deg=0
435 | g.c[0]=1
436 | for i in range(0,self.ECCstate.n):
437 | if roots[i]:
438 | r=self.ECCstate.exponents[i]
439 | g.c[g.deg+1]=1
440 | for j in range(g.deg,0,-1):
441 | g.c[j]=self.g_mul(g.c[j],r)^g.c[j-1]
442 | g.c[0]=self.g_mul(g.c[0],r)
443 | g.deg += 1
444 |
445 | # store
446 | n = g.deg+1
447 | i = 0
448 |
449 | while (n>0) :
450 |
451 | if n>32:
452 | nbits=32
453 | else:
454 | nbits=n
455 |
456 | word=0
457 | for j in range (0,nbits):
458 | if g.c[n-1-j] :
459 | word = word | pow(2,31-j)
460 | genpoly[i]=word
461 | i += 1
462 | n -= nbits
463 | self.ECCstate.ecc_bits=g.deg
464 |
465 | self.build_cyclic(genpoly);
466 |
467 |
468 | sum=0
469 | aexp=0
470 | for i in range(0,m):
471 | for j in range(0,m):
472 | sum = sum ^ self.g_pow(i*pow(2,j))
473 | if sum:
474 | aexp=self.ECCstate.exponents[i]
475 | break
476 |
477 | x=0
478 | precomp=[0] * 31
479 | remaining=m
480 |
481 | while (x<= self.ECCstate.n and remaining):
482 | y=self.g_sqrt(x)^x
483 | for i in range(0,2):
484 | r=self.g_log(y)
485 | if (y and (r<m) and precomp[r]==0):  # (lines 485-493 reconstructed)
486 | # record a precomputed solution of x^2+x=a^r for the quadratic
487 | # root solver in getroots()
488 | self.ECCstate.elp_pre[r]=x
489 | precomp[r]=1
490 | remaining -= 1
491 | break
492 | y=y^aexp
493 | x += 1
494 | 
--------------------------------------------------------------------------------
/python/trustmark/denoise.py:
--------------------------------------------------------------------------------
[... lines 1-210 of denoise.py could not be recovered from the damaged extraction ...]
211 | image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
212 | 
213 | n = torch.multinomial(torch.tensor([0.5,0.3,0.2]), 1).item() + 1
214 | stego = image
215 | for i in range(n):
216 | secret = torch.zeros_like(secret).random_(0, 2)
217 | stego = self.secret_embedder(stego, secret)[0]
218 | # stego = self.secret_embedder(image, secret)[0]
219 | out = [stego, image, secret]
220 | return out
221 |
222 | def forward(self, x):
223 | return torch.clamp(self.denoise(x), -1, 1)
224 |
225 |
226 | @torch.no_grad()
227 | def log_images(self, batch, fixed_input=False, **kwargs):
228 | log = dict()
229 | x, y, s = self.get_input(batch)
230 | x_denoised = self(x)
231 | log['clean'] = y
232 | log['stego'] = x
233 | log['denoised'] = x_denoised
234 | residual = x_denoised - y
235 | log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1
236 | return log
237 |
238 |
239 |
240 | class SimpleUnet(nn.Module):
241 | def __init__(self, dim=32) -> None:
242 | super().__init__()
243 | self.conv1 = nn.Conv2d(3, dim, 3, 1, 1)
244 | self.conv2 = nn.Conv2d(dim, dim, 3, 2, 1)
245 | self.conv3 = nn.Conv2d(dim, dim*2, 3, 2, 1)
246 | self.conv4 = nn.Conv2d(dim*2, dim*4, 3, 2, 1)
247 | self.conv5 = nn.Conv2d(dim*4, dim*8, 3, 2, 1)
248 | self.pad6 = nn.ZeroPad2d((0, 1, 0, 1))
249 | self.up6 = nn.Conv2d(dim*8, dim*4, 2, 1)
250 | self.upsample6 = nn.Upsample(scale_factor=(2, 2))
251 | self.conv6 = nn.Conv2d(dim*4 + dim*4, dim*4, 3, 1, 1)
252 | self.pad7 = nn.ZeroPad2d((0, 1, 0, 1))
253 | self.up7 = nn.Conv2d(dim*4, dim*2, 2, 1)
254 | self.upsample7 = nn.Upsample(scale_factor=(2, 2))
255 | self.conv7 = nn.Conv2d(dim*2 + dim*2, dim*2, 3, 1, 1)
256 | self.pad8 = nn.ZeroPad2d((0, 1, 0, 1))
257 | self.up8 = nn.Conv2d(dim*2, dim, 2, 1)
258 | self.upsample8 = nn.Upsample(scale_factor=(2, 2))
259 | self.conv8 = nn.Conv2d(dim+dim, dim, 3, 1, 1)
260 | self.pad9 = nn.ZeroPad2d((0, 1, 0, 1))
261 | self.up9 = nn.Conv2d(dim, dim, 2, 1)
262 | self.upsample9 = nn.Upsample(scale_factor=(2, 2))
263 | self.conv9 = nn.Conv2d(dim + dim + 3, dim, 3, 1, 1)
264 | self.conv10 = nn.Conv2d(dim, dim, 3, 1, 1)
265 | self.post = nn.Conv2d(dim, dim//2, 1)
266 | self.silu = nn.SiLU()
267 | self.out = nn.Conv2d(dim//2, 3, 1)
268 |
269 | def forward(self, image):
270 | inputs = image
271 |
272 | conv1 = thf.relu(self.conv1(inputs))
273 | conv2 = thf.relu(self.conv2(conv1))
274 | conv3 = thf.relu(self.conv3(conv2))
275 | conv4 = thf.relu(self.conv4(conv3))
276 | conv5 = thf.relu(self.conv5(conv4))
277 | up6 = thf.relu(self.up6(self.pad6(self.upsample6(conv5))))
278 | merge6 = torch.cat([conv4, up6], dim=1)
279 | conv6 = thf.relu(self.conv6(merge6))
280 | up7 = thf.relu(self.up7(self.pad7(self.upsample7(conv6))))
281 | merge7 = torch.cat([conv3, up7], dim=1)
282 | conv7 = thf.relu(self.conv7(merge7))
283 | up8 = thf.relu(self.up8(self.pad8(self.upsample8(conv7))))
284 | merge8 = torch.cat([conv2, up8], dim=1)
285 | conv8 = thf.relu(self.conv8(merge8))
286 | up9 = thf.relu(self.up9(self.pad9(self.upsample9(conv8))))
287 | merge9 = torch.cat([conv1, up9, inputs], dim=1)
288 | conv9 = thf.relu(self.conv9(merge9))
289 | conv10 = thf.relu(self.conv10(conv9))
290 | post = self.silu(self.post(conv10))
291 | out = thf.tanh(self.out(post))
292 | return out
293 |
294 |
295 |
296 |
297 | def get_obj_from_str(string, reload=False):
298 | module, cls = string.rsplit(".", 1)
299 | if reload:
300 | module_imp = importlib.import_module(module)
301 | importlib.reload(module_imp)
302 | return getattr(importlib.import_module(module, package=None), cls)
303 |
304 |
305 |
306 | def instantiate_from_config(config):
307 | if not "target" in config:
308 | if config == '__is_first_stage__':
309 | return None
310 | elif config == "__is_unconditional__":
311 | return None
312 | raise KeyError("Expected key `target` to instantiate.")
313 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
314 |
315 |
316 |
--------------------------------------------------------------------------------
/python/trustmark/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe
2 | # All Rights Reserved.
3 |
4 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | # accordance with the terms of the Adobe license agreement accompanying
6 | # it.
7 |
8 | import math
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as thf
12 | try:
13 | import lightning as pl
14 | except ImportError:
15 | import pytorch_lightning as pl
16 | import einops
17 | import numpy as np
18 | import torchvision
19 | import importlib
20 | from torchmetrics.functional import peak_signal_noise_ratio
21 | from contextlib import contextmanager
22 |
23 |
24 | class Identity(nn.Module):
25 | def __init__(self,*args,**kwargs):
26 | super().__init__()
27 | def forward(self, x):
28 | return x
29 |
30 |
31 | class TrustMark_Arch(pl.LightningModule):
32 | def __init__(self,
33 | cover_key,
34 | secret_key,
35 | secret_len,
36 | resolution,
37 | secret_encoder_config,
38 | secret_decoder_config,
39 | discriminator_config,
40 | loss_config,
41 | bit_acc_thresholds=[0.9, 0.95, 0.98],
42 | noise_config='__none__',
43 | ckpt_path="__none__",
44 | lr_scheduler='__none__',
45 | use_ema=False
46 | ):
47 | super().__init__()
48 | self.automatic_optimization = False
49 | self.cover_key = cover_key
50 | self.secret_key = secret_key
51 | secret_encoder_config.params.secret_len = secret_len
52 | secret_decoder_config.params.secret_len = secret_len
53 | secret_encoder_config.params.resolution = resolution
54 | secret_decoder_config.params.resolution = 224
55 | self.encoder = instantiate_from_config(secret_encoder_config)
56 | self.decoder = instantiate_from_config(secret_decoder_config)
57 | self.loss_layer = instantiate_from_config(loss_config)
58 | self.discriminator = instantiate_from_config(discriminator_config)
59 |
60 | if noise_config != '__none__':
61 | self.noise = instantiate_from_config(noise_config)
62 |
63 | self.lr_scheduler = None if lr_scheduler == '__none__' else lr_scheduler
64 |
65 | self.use_ema = use_ema
66 | if self.use_ema:
67 | print('Using EMA')
68 | self.encoder_ema = LitEma(self.encoder)
69 | self.decoder_ema = LitEma(self.decoder)
70 | self.discriminator_ema = LitEma(self.discriminator)
71 | print(f"Keeping EMAs of {len(list(self.encoder_ema.buffers()) + list(self.decoder_ema.buffers()) + list(self.discriminator_ema.buffers()))}.")
72 |
73 | if ckpt_path != "__none__":
74 | self.init_from_ckpt(ckpt_path, ignore_keys=[])
75 |
76 | # early training phase
77 | self.fixed_img = None
78 | self.fixed_secret = None
79 | self.register_buffer("fixed_input", torch.tensor(True))
80 | self.register_buffer("update_gen", torch.tensor(False)) # update generator to fool discriminator
81 | self.bit_acc_thresholds = bit_acc_thresholds
82 | self.crop = Identity()
83 |
84 | def init_from_ckpt(self, path, ignore_keys=list()):
85 | sd = torch.load(path, map_location="cpu")["state_dict"]
86 | keys = list(sd.keys())
87 | for k in keys:
88 | for ik in ignore_keys:
89 | if k.startswith(ik):
90 | print("Deleting key {} from state_dict.".format(k))
91 | del sd[k]
92 | self.load_state_dict(sd, strict=False)
93 | print(f"Restored from {path}")
94 |
95 |
96 |
97 | @torch.no_grad()
98 | def get_input(self, batch, bs=None):
99 | image = batch[self.cover_key]
100 | secret = batch[self.secret_key]
101 | if bs is not None:
102 | image = image[:bs]
103 | secret = secret[:bs]
104 | else:
105 | bs = image.shape[0]
106 | # encode image 1st stage
107 | image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
108 |
109 | # check if using fixed input (early training phase)
110 | # if self.training and self.fixed_input:
111 | if self.fixed_input:
112 | if self.fixed_img is None: # first iteration
113 | print('[TRAINING] Warmup - using fixed input image for now!')
114 | self.fixed_img = image.detach().clone()[:bs]
115 | self.fixed_secret = secret.detach().clone()[:bs] # use for log_images with fixed_input option only
116 | image = self.fixed_img
117 | new_bs = min(secret.shape[0], image.shape[0])
118 | image, secret = image[:new_bs], secret[:new_bs]
119 |
120 | out = [image, secret]
121 | return out
122 |
123 | def forward(self, cover, secret):
124 | # return a tuple (stego, residual)
125 | enc_out = self.encoder(cover, secret)
126 | if hasattr(self.encoder, 'return_residual') and self.encoder.return_residual:
127 | return cover + enc_out, enc_out
128 | else:
129 | return enc_out, enc_out - cover
130 |
131 |
132 |
133 | @torch.no_grad()
134 | def log_images(self, batch, fixed_input=False, **kwargs):
135 | log = dict()
136 | if fixed_input and self.fixed_img is not None:
137 | x, s = self.fixed_img, self.fixed_secret
138 | else:
139 | x, s = self.get_input(batch)
140 | stego, residual = self(x, s)
141 | if hasattr(self, 'noise') and self.noise.is_activated():
142 | img_noise = self.noise(stego, self.global_step, p=1.0)
143 | log['noised'] = img_noise
144 | log['input'] = x
145 | log['stego'] = stego
146 | log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1
147 | return log
148 |
149 |
150 | def get_obj_from_str(string, reload=False):
151 | module, cls = string.rsplit(".", 1)
152 | if reload:
153 | module_imp = importlib.import_module(module)
154 | importlib.reload(module_imp)
155 | return getattr(importlib.import_module(module, package=None), cls)
156 |
157 |
158 |
159 | def instantiate_from_config(config):
160 | if not "target" in config:
161 | if config == '__is_first_stage__':
162 | return None
163 | elif config == "__is_unconditional__":
164 | return None
165 | raise KeyError("Expected key `target` to instantiate.")
166 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
167 |
168 |
169 |
--------------------------------------------------------------------------------
/python/trustmark/models/README.md:
--------------------------------------------------------------------------------
1 | Models will be fetched to this folder on first use
2 |
--------------------------------------------------------------------------------
/rust/.cargo/config.toml:
--------------------------------------------------------------------------------
1 | [alias]
2 | xtask = "run --package xtask --"
3 |
--------------------------------------------------------------------------------
/rust/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 |
3 | /models/*
4 | !/models/.gitkeep
5 |
6 | /env/
7 |
8 | *.py
9 |
10 | .envrc
11 |
12 | .DS_Store
13 |
14 | *.png
15 |
--------------------------------------------------------------------------------
/rust/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "trustmark"
3 | description = "A Rust implementation of TrustMark"
4 | version = "0.2.0"
5 | authors = [
6 | "John Collomosse ",
7 | "Maurice Fisher ",
8 | "Andrew Halle ",
9 | ]
10 | rust-version = "1.75.0"
11 | edition = "2021"
12 | license = "MIT"
13 | keywords = ["trustmark", "cv", "watermark"]
14 | categories = ["multimedia::images"]
15 | repository = "https://github.com/adobe/trustmark"
16 |
17 | [workspace]
18 | members = ["crates/*"]
19 |
20 | [[bench]]
21 | name = "load"
22 | harness = false
23 |
24 | [[bench]]
25 | name = "encode"
26 | harness = false
27 |
28 | [dependencies]
29 | image = "0.25.6"
30 | fast_image_resize = { version = "5.1.4", features = ["image", "rayon"] }
31 | ndarray = "0.16"
32 | ort = "=2.0.0-rc.8"
33 | thiserror = "1"
34 |
35 | [dev-dependencies]
36 | criterion = "0.5"
37 |
--------------------------------------------------------------------------------
/rust/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Adobe
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/rust/README.md:
--------------------------------------------------------------------------------
1 | # TrustMark — Rust implementation
2 |
3 |
4 |
5 | An implementation in Rust of TrustMark watermarking, as described in [**TrustMark - Universal Watermarking for Arbitrary Resolution Images**](https://arxiv.org/abs/2311.18297) (`arXiv:2311.18297`) by [Tu Bui](https://www.surrey.ac.uk/people/tu-bui)[^1], [Shruti Agarwal](https://research.adobe.com/person/shruti-agarwal/)[^2], and [John Collomosse](https://www.collomosse.com)[^1] [^2].
6 |
7 | [^1]: [DECaDE](https://decade.ac.uk/) Centre for the Decentralized Digital Economy, University of Surrey, UK.
8 |
9 | [^2]: [Adobe Research](https://research.adobe.com/), San Jose, CA.
10 |
11 |
12 |
13 | This crate implements a subset of the functionality of the TrustMark Python implementation, including encoding and decoding of watermarks for all variants in binary mode. The Rust implementation provides the same levels of error correction as the Python implementation.
14 |
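Concretely, a TrustMark watermark is a fixed 100-bit string: the data payload, then BCH error-correction bits, then a 4-bit schema identifier. The four schemas trade payload for robustness — `BCH_SUPER` (40 data / 56 ECC bits, 8 correctable bit flips), `BCH_5` (61/35, 5 flips), `BCH_4` (68/28, 4 flips), and `BCH_3` (75/21, 3 flips); see `src/bits.rs`.
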
15 | Text mode watermarks and watermark removal are not implemented.
16 |
17 | Open an issue if there's something in the Python version that you'd like added to this crate!
18 |
19 | ## Quick start
20 |
21 | ### Download models
22 |
23 | In order to encode or decode watermarks, you'll need to fetch the model files. The models are distributed as ONNX files.
24 |
25 | From the workspace root (the `rust/` directory), run:
26 |
27 | ```
28 | cargo xtask fetch-models
29 | ```
30 |
31 | This command downloads models to the `models/` directory. You can move them from there as needed.
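Specifically, `fetch-models` downloads an encoder/decoder ONNX pair for each model variant — `encoder_Q.onnx`/`decoder_Q.onnx` and the corresponding `P`, `B`, and `C` files — from a hardcoded CDN URL (see `crates/xtask/src/main.rs`).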
32 |
33 | ### Run the CLI
34 |
35 | As a first step, you can run `trustmark-cli`, which is defined in this repository.
36 |
37 | From the workspace root, run:
38 |
39 | ```sh
40 | cargo run --release -p trustmark-cli -- -m ./models encode -i ../images/ghost.png -o ../images/encoded.png
41 | cargo run --release -p trustmark-cli -- -m ./models decode -i ../images/encoded.png
42 | ```
43 |
44 | The argument to the `-m` option is the path to the directory containing the downloaded models; if you moved them out of `models/`, pass their new location instead.
45 |
46 | ### Use the library
47 |
48 | Add `trustmark` to your project's `cargo` manifest with:
49 |
50 | ```
51 | cargo add trustmark
52 | ```
53 |
54 | A basic example of using `trustmark` is:
55 |
56 | ```rust
57 | use trustmark::{Trustmark, Version, Variant};
58 |
59 | let tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
60 | let input = image::open("../images/ghost.png").unwrap();
61 | let output = tm.encode("0010101".to_owned(), input, 0.95);
62 | ```
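
To read a watermark back, pass the image to `decode`; the sketch below mirrors how the bundled CLI handles the result (the input path is illustrative):

```rust
use trustmark::{Trustmark, Variant, Version};

// The variant must match the one used to encode; the BCH version is
// recovered from the watermark bits themselves.
let tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
let encoded = image::open("../images/encoded.png").unwrap();
match tm.decode(encoded) {
    Ok(bits) => println!("Found watermark: {bits}"),
    Err(trustmark::Error::CorruptWatermark) => println!("Corrupt or missing watermark"),
    Err(err) => panic!("{err:?}"),
}
```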
63 |
64 | ## Running the benchmarks
65 |
66 | ### Rust benchmarks
67 |
68 | To run the Rust benchmarks, run the following from the workspace root:
69 |
70 | ```
71 | cargo bench
72 | ```
73 |
74 | ### Python benchmarks
75 |
76 | To run the Python benchmarks, run the following from the workspace root:
77 |
78 | ```
79 | benches/load.sh && benches/encode.sh
80 | ```
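
Both scripts assume the Python `trustmark` package is importable in your current environment (for example, installed from the `python/` directory of this repository).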
81 |
--------------------------------------------------------------------------------
/rust/benches/encode.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Adobe
2 | // All Rights Reserved.
3 | //
4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | // accordance with the terms of the Adobe license agreement accompanying
6 | // it.
7 |
8 | use criterion::{criterion_group, criterion_main, Criterion};
9 | use trustmark::{Trustmark, Variant, Version};
10 |
11 | fn encode_main(c: &mut Criterion) {
12 | let tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
13 | let input = image::open("../images/ufo_240.jpg").unwrap();
14 | c.bench_function("encode", |b| {
15 | b.iter(|| {
16 | let _ = tm.encode(
17 | "0100100100100001000101001010".to_owned(),
18 | input.clone(),
19 | 0.95,
20 | );
21 | })
22 | });
23 | }
24 |
25 | criterion_group!(benches, encode_main);
26 | criterion_main!(benches);
27 |
--------------------------------------------------------------------------------
/rust/benches/encode.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | # Copyright 2025 Adobe
4 | # All Rights Reserved.
5 | #
6 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
7 | # accordance with the terms of the Adobe license agreement accompanying
8 | # it.
9 |
10 | set -euxo pipefail
11 |
12 | python -m timeit \
13 | -s "from PIL import Image; from trustmark import TrustMark; tm=TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.BCH_5); image = Image.open('../images/ufo_240.jpg').convert('RGB')" \
14 | "tm.encode(image, '0100100100100001000101001010', MODE='BINARY')"
15 |
--------------------------------------------------------------------------------
/rust/benches/load.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Adobe
2 | // All Rights Reserved.
3 | //
4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | // accordance with the terms of the Adobe license agreement accompanying
6 | // it.
7 |
8 | use criterion::{criterion_group, criterion_main, Criterion};
9 | use trustmark::{Trustmark, Variant, Version};
10 |
11 | fn load_main(c: &mut Criterion) {
12 | c.bench_function("load", |b| {
13 | b.iter(|| {
14 | let _tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
15 | })
16 | });
17 | }
18 |
19 | criterion_group!(benches, load_main);
20 | criterion_main!(benches);
21 |
--------------------------------------------------------------------------------
/rust/benches/load.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | # Copyright 2025 Adobe
4 | # All Rights Reserved.
5 | #
6 | # NOTICE: Adobe permits you to use, modify, and distribute this file in
7 | # accordance with the terms of the Adobe license agreement accompanying
8 | # it.
9 |
10 | set -euxo pipefail
11 |
12 | python -m timeit -s 'from trustmark import TrustMark' "tm=TrustMark(verbose=True, model_type='Q', encoding_type=TrustMark.Encoding.BCH_5)"
13 |
--------------------------------------------------------------------------------
/rust/crates/trustmark-cli/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "trustmark-cli"
3 | description = "A CLI for trustmark encoding/decoding"
4 | version = "0.1.0"
5 | authors = [
6 | "John Collomosse ",
7 | "Maurice Fisher ",
8 | "Andrew Halle ",
9 | ]
10 | license = "MIT"
11 | keywords = ["trustmark", "cv", "watermark", "cli"]
12 | edition = "2021"
13 | publish = false
14 |
15 | [[bin]]
16 | name = "trustmark"
17 | path = "src/main.rs"
18 |
19 | [dependencies]
20 | clap = { version = "4.5.20", features = ["derive"] }
21 | image = "0.25.6"
22 | rand = "0.8.5"
23 | trustmark = { path = "../.." }
24 |
--------------------------------------------------------------------------------
/rust/crates/trustmark-cli/README.md:
--------------------------------------------------------------------------------
1 | # TrustMark CLI
2 |
3 | The Rust implementation includes a CLI wrapper for the `trustmark` crate.
4 |
5 | ## Installation
6 |
7 | To install the CLI, run this command from the `trustmark/rust/crates/trustmark-cli` directory:
8 |
9 | ```
10 | cargo install --locked --path .
11 | ```
12 |
13 | ### Downloading models
14 |
15 | To use the CLI, you must first create a `models` directory and download the models, [if you haven't already done so](../../README.md#download-models). Enter these commands from the `trustmark/rust` directory:
16 |
17 | ```
18 | mkdir models
19 | cargo xtask fetch-models
20 | ```
21 |
22 | ## Usage
23 |
24 | View CLI help information by entering this command:
25 |
26 | ```
27 | trustmark [encode | decode] --help
28 | ```
29 |
30 | The basic command syntax is:
31 |
32 | ```
33 | trustmark --models <path> [encode | decode]
34 | ```
35 |
36 | Where `<path>` is the relative path to the directory containing the models.
37 | Use the `encode` subcommand to encode a watermark into an image and the `decode` subcommand to decode a watermark from an image.
38 |
39 | ### Encoding watermarks
40 |
41 | To encode a watermark into an image, use the `encode` subcommand:
42 |
43 | ```
44 | trustmark --models <path> encode [OPTIONS] -i <input> -o <output>
45 | ```
46 |
47 | Options:
48 |
49 | | Option | Description | Allowed Values |
50 | |--------|--------------|----------------|
51 | | `-i <input>` | Path to the image to encode. | Relative file path. |
52 | | `-o <output>` | Path to file in which to save the watermarked image. | Relative file path. |
53 | | `-w, --watermark <watermark>` | The watermark (payload) to encode. | A binary string such as `0101010101`. Only 0 and 1 characters are allowed. Maximum length is governed by the version selected. Default is a random binary string. |
54 | | `--version <version>` | The BCH version to encode with. | One of `BCH_SUPER`, `BCH_5` (default), `BCH_4`, or `BCH_3`. |
55 | | `--variant <variant>` | The model variant to encode with. | `Q` (default), `B`, `C`, and `P`. |
56 | | `--quality <quality>` | If the requested output format is JPEG, the output quality to encode with. | A number between 0 and 100. The default is 90. |
57 | | `-h, --help` | Display help information. | N/A |
58 |
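For example, to embed a fixed payload using the `BCH_5` schema and write a JPEG at quality 95 (paths and payload here are illustrative):

```sh
trustmark -m ./models encode -i ../images/ghost.png -o ../images/ghost_wm.jpg \
  --version BCH_5 -w 0101010101 --quality 95
```
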
59 | ### Decoding watermarks
60 |
61 | To decode a watermark from an image, use the `decode` subcommand:
62 |
63 | ```
64 | trustmark --models <path> decode [OPTIONS] -i <input>
65 | ```
66 |
67 | | Option | Description | Allowed Values |
68 | |--------|--------------|----------------|
69 | | `-i <input>` | Path to the image to decode. | Relative file path. |
70 | | `--variant <variant>` | The model variant to decode with. Must match the variant used to encode the watermark. | `Q` (default), `B`, `C`, and `P`. |
71 | | `-h, --help` | Display help information. | N/A |
72 |
73 | ## Examples
74 |
75 | To encode a watermark into one of the sample images, run this command from the workspace root:
76 |
77 | ```sh
78 | trustmark -m ./models encode -i ../images/ghost.png -o ../images/ghost_encoded.png
79 | ```
80 |
81 | Then to decode the watermark from this image, run this command from the workspace root:
82 |
83 | ```sh
84 | trustmark -m ./models decode -i ../images/ghost_encoded.png
85 | ```
86 |
87 | You'll see something like this in your terminal:
88 |
89 | ```
90 | Found watermark: 0101111001101001000000011011010100100010011101101101000001101
91 | ```
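
The payload shown is 61 bits long because the CLI defaults to the `BCH_5` schema, which carries 61 data bits. The schema is recovered from the watermark itself when decoding, but `--variant` must match the variant used to encode.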
92 |
--------------------------------------------------------------------------------
/rust/crates/trustmark-cli/src/main.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2024 Adobe
2 | // All Rights Reserved.
3 | //
4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | // accordance with the terms of the Adobe license agreement accompanying
6 | // it.
7 |
8 | use std::{fs::OpenOptions, path::PathBuf};
9 |
10 | use clap::{Parser, Subcommand};
11 | use image::{codecs::jpeg::JpegEncoder, ImageFormat};
12 | use rand::{distributions::Standard, prelude::Distribution as _};
13 | use trustmark::{Trustmark, Variant, Version};
14 |
15 | #[derive(Debug, Parser)]
16 | struct Args {
17 | #[arg(short, long)]
18 | models: PathBuf,
19 | #[command(subcommand)]
20 | command: Command,
21 | }
22 |
23 | #[derive(Debug, Subcommand)]
24 | enum Command {
25 | /// Encode a watermark into an image
26 | Encode {
27 | /// The image to encode.
28 | #[arg(short)]
29 | input: PathBuf,
30 | /// The path to save the watermarked image.
31 | #[arg(short)]
32 | output: PathBuf,
33 | /// The watermark to encode. Defaults to random if not specified.
34 | #[arg(short, long)]
35 | watermark: Option<String>,
36 | /// The BCH version to encode with. Defaults to `Bch5`.
37 | #[arg(long)]
38 | version: Option<Version>,
39 | /// The model variant to encode with.
40 | #[arg(long)]
41 | variant: Option<Variant>,
42 | /// If the requested output is JPEG, the quality to use for encoding.
43 | #[arg(long)]
44 | quality: Option<u8>,
45 | },
46 | /// Decode a watermark from an image
47 | Decode {
48 | #[arg(short)]
49 | input: PathBuf,
50 | /// The model variant to decode with.
51 | #[arg(long)]
52 | variant: Option<Variant>,
53 | },
54 | }
55 |
56 | impl Command {
57 | /// Extract the version to use from this `Command`.
58 | fn get_version(&self) -> Version {
59 | match self {
60 | Command::Encode {
61 | version: Some(version),
62 | ..
63 | } => *version,
64 | _ => Version::Bch5,
65 | }
66 | }
67 |
68 | /// Extract the variant to use from this `Command`.
69 | fn get_variant(&self) -> Variant {
70 | match self {
71 | Command::Encode {
72 | variant: Some(variant),
73 | ..
74 | } => *variant,
75 | Command::Decode {
76 | variant: Some(variant),
77 | ..
78 | } => *variant,
79 | _ => Variant::Q,
80 | }
81 | }
82 | }
83 |
84 | /// Generate a random watermark with as many bits as specified by `bits`.
85 | ///
86 | /// # Example
87 | ///
88 | /// ```rust
89 | /// # fn main() {
90 | /// println!("{}", gen_watermark(4)); // 1010
91 | /// # }
92 | /// ```
93 | fn gen_watermark(bits: usize) -> String {
94 | let mut rng = rand::thread_rng();
95 | let v: Vec<bool> = Standard.sample_iter(&mut rng).take(bits).collect();
96 | v.into_iter()
97 | .map(|bit| if bit { '1' } else { '0' })
98 | .collect()
99 | }
100 |
101 | fn main() {
102 | let args = Args::parse();
103 | let tm = Trustmark::new(
104 | &args.models,
105 | args.command.get_variant(),
106 | args.command.get_version(),
107 | )
108 | .unwrap();
109 | match args.command {
110 | Command::Encode {
111 | input,
112 | output,
113 | watermark,
114 | version,
115 | quality,
116 | ..
117 | } => {
118 | let input = image::open(input).unwrap();
119 | let watermark = watermark.unwrap_or_else(|| {
120 | gen_watermark(version.unwrap_or(Version::Bch5).data_bits().into())
121 | });
122 | let encoded = tm.encode(watermark.clone(), input, 0.95).unwrap();
123 |
124 | let format = ImageFormat::from_path(&output).unwrap();
125 | match format {
126 | // JPEG encoding can make visual artifacts worse, so we encode with a higher
127 | // quality than the default (or the quality requested by the user).
128 | ImageFormat::Jpeg => {
129 | let quality = quality.unwrap_or(90);
130 | let mut writer = OpenOptions::new()
131 | .write(true)
132 | .create(true)
133 | .truncate(true)
134 | .open(&output)
135 | .unwrap();
136 | let encoder = JpegEncoder::new_with_quality(&mut writer, quality);
137 | encoded.to_rgb8().write_with_encoder(encoder).unwrap();
138 | }
139 | _ => {
140 | encoded.to_rgba8().save(&output).unwrap();
141 | }
142 | }
143 | }
144 | Command::Decode { input, .. } => {
145 | let input = image::open(input).unwrap();
146 | match tm.decode(input) {
147 | Ok(decoded) => println!("Found watermark: {decoded}"),
148 | Err(trustmark::Error::CorruptWatermark) => {
149 | println!("Corrupt or missing watermark")
150 | }
151 | err => panic!("{err:?}"),
152 | }
153 | }
154 | }
155 | }
156 |
--------------------------------------------------------------------------------
/rust/crates/xtask/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "xtask"
3 | version = "0.1.0"
4 | edition = "2021"
5 | publish = false
6 |
7 | [dependencies]
8 | clap = { version = "4.5.20", features = ["derive"] }
9 | ureq = "2.10.1"
10 |
--------------------------------------------------------------------------------
/rust/crates/xtask/src/main.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2024 Adobe
2 | // All Rights Reserved.
3 | //
4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | // accordance with the terms of the Adobe license agreement accompanying
6 | // it.
7 |
8 | use std::{
9 | fs::OpenOptions,
10 | io::{self, Read as _},
11 | };
12 |
13 | use clap::Parser;
14 |
15 | #[derive(Debug, Parser)]
16 | enum Args {
17 | FetchModels,
18 | }
19 |
20 | fn main() {
21 | let args = Args::parse();
22 | match args {
23 | Args::FetchModels => fetch_models(),
24 | }
25 | }
26 |
27 | /// Fetch all known models.
28 | fn fetch_models() {
29 | fetch_model("decoder_Q.onnx");
30 | fetch_model("encoder_Q.onnx");
31 | fetch_model("decoder_P.onnx");
32 | fetch_model("encoder_P.onnx");
33 | fetch_model("decoder_B.onnx");
34 | fetch_model("encoder_B.onnx");
35 | fetch_model("decoder_C.onnx");
36 | fetch_model("encoder_C.onnx");
37 | }
38 |
39 | /// Fetch a single model identified by `filename`.
40 | ///
41 | /// Models are fetched from a hardcoded CDN URL.
42 | fn fetch_model(filename: &str) {
43 | let root = "https://cc-assets.netlify.app/watermarking/trustmark-models";
44 | let model_url = format!("{root}/{filename}",);
45 | let mut decoder = ureq::get(&model_url)
46 | .call()
47 | .unwrap()
48 | .into_reader()
49 | .take(100_000_000);
50 | let mut file = OpenOptions::new()
51 | .write(true)
52 | .create_new(true)
53 | .open(format!("models/{filename}"))
54 | .unwrap();
55 | io::copy(&mut decoder, &mut file).unwrap();
56 | }
57 |
--------------------------------------------------------------------------------
/rust/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/trustmark/31bad43c58decbf7286d63f59c80b5678b4071a7/rust/models/.gitkeep
--------------------------------------------------------------------------------
/rust/src/bits.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2024 Adobe
2 | // All Rights Reserved.
3 | //
4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | // accordance with the terms of the Adobe license agreement accompanying
6 | // it.
7 |
8 | use std::{fmt::Display, str::FromStr};
9 |
10 | use ndarray::{Array1, ArrayD, Axis};
11 | use ort::{TensorValueType, Value};
12 |
13 | const VERSION_BITS: u16 = 4;
14 |
15 | mod bch;
16 |
17 | #[derive(Debug)]
18 | pub(super) struct Bits(String);
19 |
20 | /// Error type for the `bits` module.
21 | #[derive(Debug, thiserror::Error)]
22 | pub enum Error {
23 | /// Something went wrong while doing inference.
24 | #[error("onnx error: {0}")]
25 | Ort(#[from] ort::Error),
26 |
27 | /// A character was encountered that was not a '0' or a '1'. Strings that specify bitstrings
28 | /// should only use these two characters.
29 | #[error("allowed chars are '0' and '1'")]
30 | InvalidChar,
31 |
32 | /// A bitstring has too many bits to fit in the requested version's data payload.
33 | #[error(
34 | "input bitstring ({bits} bits) has more bits than version allows ({version_allows} bits)"
35 | )]
36 | InvalidDataLength { version_allows: usize, bits: usize },
37 |
38 | /// A bitstring of a certain length is required, and a different length was encountered.
39 | #[error("must be of length 100")]
40 | InvalidLength,
41 |
42 | /// Model input/output did not have the expected dimensions.
43 | #[error("invalid dimensions")]
44 | InvalidDim,
45 |
46 | /// String does not represent a known version.
47 | #[error("invalid version")]
48 | InvalidVersion,
49 |
50 | /// Watermark is missing or corrupted.
51 | ///
52 | /// Either the image did not have a valid watermark, or too many transmission errors/image
53 | /// artifacts occurred such that the watermark is no longer recoverable.
54 | #[error("corrupt watermark")]
55 | CorruptWatermark,
56 | }
57 |
58 | impl Bits {
59 | /// Constructs a `Bits`, adding in the additional error correction and schema bits.
60 | pub(super) fn apply_error_correction_and_schema(
61 | mut input: String,
62 | version: Version,
63 | ) -> Result<Self, Error> {
64 | let data_bits: usize = version.data_bits().into();
65 |
66 | if input.chars().any(|c| c != '0' && c != '1') {
67 | return Err(Error::InvalidChar);
68 | }
69 |
70 | if input.len() > data_bits {
71 | return Err(Error::InvalidDataLength {
72 | bits: input.len(),
73 | version_allows: data_bits,
74 | });
75 | }
76 |
77 | // pad the input
78 | input.push_str(&"0".repeat(data_bits - input.len() + (8 - data_bits % 8)));
79 |
80 | // pack the input into bytes
81 | let data: Vec<u8> = input
82 | .as_bytes()
83 | .chunks(8)
84 | .map(|chunk| u8::from_str_radix(std::str::from_utf8(chunk).unwrap(), 2).unwrap())
85 | .collect();
86 |
87 | // calculate the error correction bits
88 | let mut ecc_state = bch::bch_init(version.allowed_bit_flips() as u32, bch::POLYNOMIAL);
89 | let ecc = bch::bch_encode(&mut ecc_state, &data);
90 |
91 | // form a bitstring from the error correction bits
92 | let mut error_correction: String = ecc
93 | .iter()
94 | .map(|byte| format!("{byte:08b}"))
95 | .collect::<Vec<String>>()
96 | .join("");
97 |
98 | // split off unneeded padding
99 | input.truncate(data_bits);
100 | error_correction.truncate(version.ecc_bits().into());
101 |
102 | // form the encoded string
103 | Ok(Self(format!(
104 | "{input}{error_correction}{}",
105 | version.bitstring()
106 | )))
107 | }
108 |
109 | /// Get the data out of a `Bits` by removing the error correction bits.
110 | pub(super) fn get_data(self) -> String {
111 | let version = self.get_version();
112 | let Self(mut s) = self;
113 | s.truncate(version.data_bits().into());
114 | s
115 | }
116 |
117 | /// Get the version from the bits.
118 | pub(super) fn get_version(&self) -> Version {
119 | match &self.0[98..100] {
120 | "00" => Version::BchSuper,
121 | "01" => Version::Bch5,
122 | "10" => Version::Bch4,
123 | "11" => Version::Bch3,
124 | _ => unreachable!(),
125 | }
126 | }
127 |
128 | /// Construct a `Bits` from a bitstring.
129 | ///
130 | /// This function checks for bitflips in the bitstring using the error-correcting bits, and
131 | /// corrects them if there are fewer bitflips than are supported by the version. As a last
132 | /// resort, this function checks for bitflips in the version identifier by trying all possible
133 | /// versions.
134 | fn new(s: String) -> Result<Self, Error> {
135 | if s.chars().any(|c| c != '0' && c != '1') {
136 | return Err(Error::InvalidChar);
137 | }
138 |
139 | if s.len() != 100 {
140 | return Err(Error::InvalidLength);
141 | }
142 |
143 | let version: Version = Version::from_bitstring(&s[96..]).unwrap_or_default();
144 |
145 | if let Ok(bits) = Bits::new_with_version(&s, version) {
146 | Ok(bits)
147 | } else {
148 | let mut versions = vec![
149 | Version::Bch3,
150 | Version::Bch4,
151 | Version::Bch5,
152 | Version::BchSuper,
153 | ];
154 | versions.retain(|v| *v != version);
155 | let mut res = None;
156 | for version in versions {
157 | res = Some(Bits::new_with_version(&s, version));
158 | if res.as_ref().unwrap().is_ok() {
159 | return res.unwrap();
160 | }
161 | }
162 | res.unwrap()
163 | }
164 | }
165 |
166 | fn new_with_version(s: &str, version: Version) -> Result<Self, Error> {
167 | let data_bits: usize = version.data_bits().into();
168 | let ecc_bits: usize = version.ecc_bits().into();
169 |
170 | let mut data = s[..data_bits].to_string();
171 | let mut ecc = s[data_bits..data_bits + ecc_bits].to_string();
172 |
173 | // pad
174 | data.push_str(&"0".repeat(data_bits - data.len() + (8 - data_bits % 8)));
175 | ecc.push_str(&"0".repeat(ecc_bits - ecc.len() + (8 - ecc_bits % 8)));
176 |
177 | // pack into bytes
178 | let mut data: Vec<u8> = data
179 | .as_bytes()
180 | .chunks(8)
181 | .map(|chunk| u8::from_str_radix(std::str::from_utf8(chunk).unwrap(), 2).unwrap())
182 | .collect();
183 | let ecc: Vec<u8> = ecc
184 | .as_bytes()
185 | .chunks(8)
186 | .map(|chunk| u8::from_str_radix(std::str::from_utf8(chunk).unwrap(), 2).unwrap())
187 | .collect();
188 |
189 | // validate and correct
190 | let mut ecc_state = bch::bch_init(version.allowed_bit_flips() as u32, bch::POLYNOMIAL);
191 | let bitflips = bch::bch_decode(&mut ecc_state, &mut data, &ecc);
192 |
193 | if bitflips > version.allowed_bit_flips() {
194 | return Err(Error::CorruptWatermark);
195 | }
196 |
197 | // unpack data and ecc
198 | let mut data: String = data
199 | .iter()
200 | .map(|byte| format!("{byte:08b}"))
201 | .collect::<Vec<String>>()
202 | .join("");
203 | let mut ecc: String = ecc
204 | .iter()
205 | .map(|byte| format!("{byte:08b}"))
206 | .collect::<Vec<String>>()
207 | .join("");
208 | data.truncate(data_bits);
209 | ecc.truncate(ecc_bits);
210 |
211 | Ok(Bits(format!("{data}{ecc}{}", version.bitstring())))
212 | }
213 | }
214 |
215 | impl From<Bits> for ort::Value<TensorValueType<f32>> {
216 | fn from(Bits(s): Bits) -> Self {
217 | let floats: Vec<f32> = s
218 | .chars()
219 | .map(|c| match c {
220 | '0' => 0.0,
221 | '1' => 1.0,
222 | _ => unreachable!(),
223 | })
224 | .collect();
225 |
226 | let array = Array1::from(floats);
227 | Value::from_array(array.insert_axis(Axis(0))).unwrap()
228 | }
229 | }
230 |
231 | impl TryFrom<ArrayD<f32>> for Bits {
232 | type Error = Error;
233 |
234 | fn try_from(array: ArrayD<f32>) -> Result<Self, Self::Error> {
235 | if array.shape() != [1, 100] {
236 | return Err(Error::InvalidDim);
237 | }
238 | let array = array.remove_axis(Axis(0));
239 | let mut s = String::new();
240 | for bit in array.iter() {
241 | let c = if *bit < 0. { '0' } else { '1' };
242 | s.push(c);
243 | }
244 |
245 | Bits::new(s)
246 | }
247 | }
248 |
249 | /// The error correction schema
250 | #[derive(Debug, Default, Copy, Clone, PartialEq)]
251 | pub enum Version {
252 | /// Tolerates 8 bit flips
253 | #[default]
254 | BchSuper,
255 | /// Tolerates 5 bit flips
256 | Bch5,
257 | /// Tolerates 4 bit flips
258 | Bch4,
259 | /// Tolerates 3 bit flips
260 | Bch3,
261 | }
262 |
263 | impl Display for Version {
264 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 | let s = match self {
266 | Version::BchSuper => "BCH_SUPER",
267 | Version::Bch5 => "BCH_5",
268 | Version::Bch4 => "BCH_4",
269 | Version::Bch3 => "BCH_3",
270 | };
271 | write!(f, "{s}")
272 | }
273 | }
274 |
275 | impl FromStr for Version {
276 | type Err = Error;
277 |
278 | fn from_str(s: &str) -> Result<Self, Self::Err> {
279 | let version = match s {
280 | "BCH_SUPER" => Version::BchSuper,
281 | "BCH_5" => Version::Bch5,
282 | "BCH_4" => Version::Bch4,
283 | "BCH_3" => Version::Bch3,
284 | _ => return Err(Error::InvalidVersion),
285 | };
286 |
287 | Ok(version)
288 | }
289 | }
290 |
291 | impl Version {
292 | /// Get the number of allowed bit flips for this version.
293 | fn allowed_bit_flips(&self) -> u8 {
294 | match self {
295 | Version::BchSuper => 8,
296 | Version::Bch5 => 5,
297 | Version::Bch4 => 4,
298 | Version::Bch3 => 3,
299 | }
300 | }
301 |
302 | /// Get the number of data bits for this version.
303 | pub fn data_bits(&self) -> u16 {
304 | match self {
305 | Version::BchSuper => 40,
306 | Version::Bch5 => 61,
307 | Version::Bch4 => 68,
308 | Version::Bch3 => 75,
309 | }
310 | }
311 |
312 | /// Get the bitstring which indicates this version.
313 | fn bitstring(&self) -> String {
314 | match self {
315 | Version::BchSuper => "0000".to_owned(),
316 | Version::Bch5 => "0001".to_owned(),
317 | Version::Bch4 => "0010".to_owned(),
318 | Version::Bch3 => "0011".to_owned(),
319 | }
320 | }
321 |
322 | /// Parse a version from a bitstring.
323 | fn from_bitstring(s: &str) -> Result<Self, Error> {
324 | Ok(match s {
325 | "0000" => Version::BchSuper,
326 | "0001" => Version::Bch5,
327 | "0010" => Version::Bch4,
328 | "0011" => Version::Bch3,
329 | _ => return Err(Error::InvalidVersion),
330 | })
331 | }
332 |
333 | /// Get the number of error correcting bits for this version.
334 | fn ecc_bits(&self) -> u16 {
335 | 100 - VERSION_BITS - self.data_bits()
336 | }
337 | }
338 |
339 | #[cfg(test)]
340 | mod tests {
341 | use super::*;
342 |
343 | #[test]
344 | fn get_version() {
345 | let input = "1011011110011000111111000000011111011111011100000110110110111000110010101101111010011011000010000001".to_owned();
346 | let bits = Bits(input);
347 | assert_eq!(bits.get_version(), Version::Bch5);
348 | }
349 |
350 | #[test]
351 | fn get_data() {
352 | let input = "1011011110011000111111000000011111011111011100000110110110111000110010101101111010011011000010000001".to_owned();
353 | let bits = Bits(input);
354 | assert_eq!(
355 | bits.get_data(),
356 | "1011011110011000111111000000011111011111011100000110110110111"
357 | );
358 | }
359 |
360 | #[test]
361 | fn new() {
362 | let input = "1011011110011000111111000000011111011111011100000110110110111000110010101101111010011011000010000001".to_owned();
363 | let bits = Bits::new(input).unwrap();
364 | assert_eq!(
365 | bits.get_data(),
366 | "1011011110011000111111000000011111011111011100000110110110111"
367 | );
368 | }
369 |
370 | #[test]
371 | fn fully_corrupted() {
372 | let input = "0000000000000000000000000000000000000000000100000110110110111000110010101101111010011011000010000001".to_owned();
373 | let err = Bits::new(input).unwrap_err();
374 | assert_eq!(err.to_string(), "corrupt watermark");
375 | }
376 |
377 | #[test]
378 | fn single_bitflip() {
379 | let input = "0011011110011000111111000000011111011111011100000110110110111000110010101101111010011011000010000001".to_owned();
380 | let bits = Bits::new(input).unwrap();
381 | assert_eq!(
382 | bits.get_data(),
383 | "1011011110011000111111000000011111011111011100000110110110111"
384 | );
385 | }
386 |
387 | #[test]
388 | fn single_bitflip_and_corrupted_version() {
389 | let input = "0011011110011000111111000000011111011111011100000110110110111000110010101101111010011011000010000011".to_owned();
390 | let bits = Bits::new(input).unwrap();
391 | assert_eq!(
392 | bits.get_data(),
393 | "1011011110011000111111000000011111011111011100000110110110111"
394 | );
395 | }
396 |
397 | #[test]
398 | fn invalid_bitstring() {
399 | let err = Bits::apply_error_correction_and_schema("hello".to_string(), Version::Bch5)
400 | .unwrap_err();
401 | assert!(matches!(err, Error::InvalidChar));
402 | }
403 |
404 | #[test]
405 | fn too_long_input() {
406 | let err =
407 | Bits::apply_error_correction_and_schema("0".repeat(200), Version::Bch5).unwrap_err();
408 | assert!(matches!(err, Error::InvalidDataLength { .. }));
409 | }
410 |
411 | #[test]
412 | fn corrupt() {
413 | let err = Bits::new("1".repeat(100)).unwrap_err();
414 | assert!(matches!(err, Error::CorruptWatermark));
415 | }
416 |
417 | #[test]
418 | fn invalid_dim() {
419 | let ar = ArrayD::::zeros(ndarray::IxDyn(&[3]));
420 | let res: Result<Bits, Error> = ar.try_into();
421 | assert!(matches!(res.unwrap_err(), Error::InvalidDim));
422 | }
423 | }
424 |
--------------------------------------------------------------------------------
/rust/src/image_processing.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2024 Adobe
2 | // All Rights Reserved.
3 | //
4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | // accordance with the terms of the Adobe license agreement accompanying
6 | // it.
7 |
8 | use std::cmp;
9 |
10 | use fast_image_resize::{ResizeAlg, ResizeOptions, Resizer};
11 | use image::{
12 | imageops::{self, FilterType},
13 | DynamicImage, GenericImageView as _, GrayAlphaImage, GrayImage, ImageBuffer, Pixel as _,
14 | Rgb32FImage, RgbImage, Rgba32FImage, RgbaImage,
15 | };
16 | use ndarray::{s, Array, ArrayD, Axis, ShapeError};
17 | use ort::TensorValueType;
18 |
19 | use crate::Variant;
20 |
21 | /// Re-normalize a floating point value (either scalar or array) from the range [0,1] to the range
22 | /// [-1, 1].
23 | macro_rules! convert_from_0_1_to_neg1_1 {
24 | ($f:expr) => {
25 | $f * 2. - 1.
26 | };
27 | }
28 |
29 | /// Re-normalize a floating point value (either scalar or array) from the range [-1, 1] to the
30 | /// range [0, 1].
31 | macro_rules! convert_from_neg1_1_to_0_1 {
32 | ($f:expr) => {
33 | ($f + 1.) / 2.
34 | };
35 | }
36 |
37 | pub(super) struct ModelImage(pub(super) u32, pub(super) Variant, pub(super) DynamicImage);
38 |
39 | /// The error type for the `image_processing` module.
40 | #[derive(Debug, thiserror::Error)]
41 | pub enum Error {
42 | /// Something went wrong during inference.
43 | #[error("onnx error: {0}")]
44 | Ort(#[from] ort::Error),
45 |
46 | /// We were unable to make an `ndarray::Array` of the requested shape.
47 | #[error("shape error: {0}")]
48 | Shape(#[from] ShapeError),
49 |
50 | /// The input array has an unexpected shape.
51 | #[error("invalid shape")]
52 | InvalidShape,
53 |
54 | /// We could not create an `ImageBuffer` with the requested array.
55 | #[error("invalid image")]
56 | Image,
57 |
58 | /// We were unable to resize the input image.
59 | #[error("resize error: {0}")]
60 | Resize(#[from] fast_image_resize::ResizeError),
61 | }
62 |
63 | impl TryFrom<ModelImage> for ort::Value<TensorValueType<f32>> {
64 | type Error = Error;
65 |
66 | fn try_from(ModelImage(size, variant, img): ModelImage) -> Result<Self, Self::Error> {
67 | let (w, h, xpos, ypos) = center_crop_size_and_offset(variant, &img);
68 |
69 | let options = ResizeOptions::new()
70 | .crop(xpos as f64, ypos as f64, w as f64, h as f64)
71 | .resize_alg(ResizeAlg::Interpolation(
72 | fast_image_resize::FilterType::Bilinear,
73 | ));
74 | let modified_img = resize_img(&img, size, size, options)?;
75 |
76 | let img = modified_img.into_rgb32f().into_vec();
77 | let array = Array::from(img);
78 |
79 | // The `image` crate normalizes to `[0,1]`. Trustmark wants images normalized to `[-1,1]`.
80 | let array = convert_from_0_1_to_neg1_1!(array);
81 |
82 | let mut array = array
83 | .to_shape([size as usize, size as usize, 3])?
84 | .insert_axis(Axis(3))
85 | .reversed_axes();
86 | array.swap_axes(2, 3);
87 | assert_eq!(array.shape(), &[1, 3, size as usize, size as usize]);
88 | Ok(ort::Value::from_array(&array)?)
89 | }
90 | }
91 |
92 | impl TryFrom<(u32, Variant, ArrayD<f32>)> for ModelImage {
93 | type Error = Error;
94 |
95 | fn try_from(
96 | (size, variant, mut array): (u32, Variant, ArrayD<f32>),
97 | ) -> Result<Self, Self::Error> {
98 | let &[1, 3, height, width] = &array.shape().to_owned()[..] else {
99 | return Err(Error::InvalidShape);
100 | };
101 | array.swap_axes(2, 3);
102 | let array = array.reversed_axes().remove_axis(Axis(3));
103 | let array = array.to_shape([width * height * 3])?;
104 |
105 | // The `image` crate normalizes to `[0,1]`. Trustmark wants images normalized to `[-1,1]`.
106 | let array = convert_from_neg1_1_to_0_1!(array);
107 |
108 | let image = Rgb32FImage::from_vec(width as u32, height as u32, array.to_vec())
109 | .ok_or(Error::Image)?;
110 |
111 | Ok(Self(size, variant, image.into()))
112 | }
113 | }
114 |
115 | /// Apply `residual` to the `input`.
116 | ///
117 | /// This function upscales `residual` to be the size of `input`, then adds `residual` to the
118 | /// `input`.
119 | pub(super) fn apply_residual(input: DynamicImage, residual: DynamicImage) -> DynamicImage {
120 | let has_alpha = input.color().has_alpha();
121 | let (w, h) = input.dimensions();
122 |
123 | let applied = {
124 | let input = input.clone().into_rgba32f();
125 | let mut target = input.clone();
126 |
127 | let residual = residual.resize_exact(w, h, FilterType::Triangle);
128 | let residual = residual.into_rgba32f();
129 |
130 | for ((target, residual), original) in target
131 | .pixels_mut()
132 | .zip(residual.pixels())
133 | .zip(input.pixels())
134 | {
135 | target.apply2(residual, |x, y| {
136 | let x = convert_from_0_1_to_neg1_1!(x);
137 | let y = convert_from_0_1_to_neg1_1!(y);
138 |
139 | convert_from_neg1_1_to_0_1!(f32::min(x + y, 1.0))
140 | });
141 | target[3] = original[3];
142 | }
143 |
144 | target
145 | };
146 |
147 | if has_alpha {
148 | let mut input = input.into_rgba32f();
149 | imageops::replace(&mut input, &applied, 0, 0);
150 | input.into()
151 | } else {
152 | let mut input = input.into_rgb32f();
153 | let applied = DynamicImage::ImageRgba32F(applied).into_rgb32f();
154 | imageops::replace(&mut input, &applied, 0, 0);
155 | input.into()
156 | }
157 | }
158 |
159 | /// Return the size and offset of the "center-cropped" image.
160 | ///
161 | /// Returns `(width, height, xpos, ypos)` for the region to crop.
162 | ///
163 | /// For tall, skinny images or short, wide images, we crop a square with side length equal to
164 | /// the shorter side out of the center of the image before passing it to the model.
165 | fn center_crop_size_and_offset(variant: Variant, img: &DynamicImage) -> (u32, u32, u32, u32) {
166 | let (width, height) = img.dimensions();
167 |
168 | if height > width * 2 || width > height * 2 || variant == Variant::P {
169 | let m = cmp::min(height, width);
170 | let offset = (cmp::max(height, width) - m) / 2;
171 |
172 | let xpos;
173 | let ypos;
174 | if height > width {
175 | xpos = 0;
176 | ypos = offset;
177 | } else {
178 | ypos = 0;
179 | xpos = offset;
180 | }
181 |
182 | (m, m, xpos, ypos)
183 | } else {
184 | (width, height, 0, 0)
185 | }
186 | }
187 |
188 | /// Returns a new `DynamicImage`, resized to `width` by `height` with the specified `ResizeOptions`.
189 | fn resize_img(
190 | img: &DynamicImage,
191 | width: u32,
192 | height: u32,
193 | options: ResizeOptions,
194 | ) -> Result<DynamicImage, Error> {
195 | let mut modified_img = match img {
196 | DynamicImage::ImageLuma8(_) => DynamicImage::ImageLuma8(GrayImage::new(width, height)),
197 | DynamicImage::ImageLumaA8(_) => {
198 | DynamicImage::ImageLumaA8(GrayAlphaImage::new(width, height))
199 | }
200 | DynamicImage::ImageRgb8(_) => DynamicImage::ImageRgb8(RgbImage::new(width, height)),
201 | DynamicImage::ImageRgba8(_) => DynamicImage::ImageRgba8(RgbaImage::new(width, height)),
202 | DynamicImage::ImageLuma16(_) => DynamicImage::ImageLuma16(ImageBuffer::new(width, height)),
203 | DynamicImage::ImageLumaA16(_) => {
204 | DynamicImage::ImageLumaA16(ImageBuffer::new(width, height))
205 | }
206 | DynamicImage::ImageRgb16(_) => DynamicImage::ImageRgb16(ImageBuffer::new(width, height)),
207 | DynamicImage::ImageRgba16(_) => DynamicImage::ImageRgba16(ImageBuffer::new(width, height)),
208 | DynamicImage::ImageRgb32F(_) => DynamicImage::ImageRgb32F(Rgb32FImage::new(width, height)),
209 | DynamicImage::ImageRgba32F(_) => {
210 | DynamicImage::ImageRgba32F(Rgba32FImage::new(width, height))
211 | }
212 | // Technically unreachable, but we error for safety.
213 | _ => return Err(Error::Image),
214 | };
215 | Resizer::new().resize(img, &mut modified_img, &options)?;
216 |
217 | Ok(modified_img)
218 | }
219 |
220 | /// Applies the mean padding boundary artifact mitigation.
221 | ///
222 | /// Center-cropped images can show a visible vertical seam along the boundary of the residual.
223 | /// This transformation makes that boundary less visible.
224 | pub(super) fn remove_boundary_artifact(
225 | mut residual: ArrayD<f32>,
226 | (width, height): (usize, usize),
227 | _variant: Variant,
228 | ) -> ArrayD<f32> {
229 | // We're going to replace the border of the residual with the mean and also pad the non-center
230 | // areas with the mean value.
231 | let channel_means: Vec<f32> = (0_usize..3)
232 | .map(|i| residual.slice(s![.., i, .., ..]).mean().unwrap())
233 | .collect();
234 |
235 | // We want one dimension of the output to be 256 and we want the aspect ratio of the output
236 | // to match the input image.
237 | let mut mean_padded: ndarray::Array4<f32> = if width > height {
238 | let other = ((width as f32 / height as f32) * 256.0) as usize;
239 | ndarray::Array4::zeros([1, 3, 256_usize, other])
240 | } else {
241 | let other = ((height as f32 / width as f32) * 256.0) as usize;
242 | ndarray::Array4::zeros([1, 3, other, 256])
243 | };
244 |
245 | // This softens the transition between the residual area and the rest of the image.
246 | let border = 2;
247 | for (i, mean) in channel_means.iter().enumerate() {
248 | residual.slice_mut(s![0, i, ..border, ..]).fill(*mean);
249 | residual.slice_mut(s![0, i, -border.., ..]).fill(*mean);
250 | residual.slice_mut(s![0, i, .., -border..]).fill(*mean);
251 | residual.slice_mut(s![0, i, .., ..border]).fill(*mean);
252 | mean_padded.slice_mut(s![0, i, .., ..]).fill(*mean);
253 | }
254 |
255 | if width > height {
256 | let other = ((width as f32 / height as f32) * 256.0) as usize;
257 | let leftover = (other - 256) / 2;
258 | mean_padded
259 | .slice_mut(s![.., .., .., leftover..(leftover + 256)])
260 | .assign(&residual);
261 | } else {
262 | let other = ((height as f32 / width as f32) * 256.0) as usize;
263 | let leftover = (other - 256) / 2;
264 | mean_padded
265 | .slice_mut(s![.., .., leftover..(leftover + 256), ..])
266 | .assign(&residual);
267 | }
268 |
269 | mean_padded.into_dyn()
270 | }
271 |
272 | #[cfg(test)]
273 | mod tests {
274 | use super::*;
275 |
276 | #[test]
277 | fn renormalize_from_0_1() {
278 | assert_eq!(convert_from_0_1_to_neg1_1!(0.), -1.);
279 | assert_eq!(convert_from_0_1_to_neg1_1!(0.5), 0.);
280 | assert_eq!(convert_from_0_1_to_neg1_1!(0.99), 0.98);
281 | }
282 |
283 | #[test]
284 | fn renormalize_from_neg1_1() {
285 | assert_eq!(convert_from_neg1_1_to_0_1!(-1.), 0.);
286 | assert_eq!(convert_from_neg1_1_to_0_1!(0.5), 0.75);
287 | assert_eq!(convert_from_neg1_1_to_0_1!(-0.1), 0.45);
288 | }
289 |
290 | #[test]
291 | fn normal_image() {
292 | let image = DynamicImage::new(100, 110, image::ColorType::L8);
293 | assert_eq!(
294 | center_crop_size_and_offset(Variant::Q, &image),
295 | (100, 110, 0, 0)
296 | );
297 | }
298 |
299 | #[test]
300 | fn skinny_image() {
301 | let image = DynamicImage::new(10, 100, image::ColorType::L8);
302 | assert_eq!(
303 | center_crop_size_and_offset(Variant::Q, &image),
304 | (10, 10, 0, 45)
305 | );
306 | }
307 |
308 | #[test]
309 | fn wide_image() {
310 | let image = DynamicImage::new(101, 10, image::ColorType::L8);
311 | assert_eq!(
312 | center_crop_size_and_offset(Variant::Q, &image),
313 | (10, 10, 45, 0)
314 | );
315 | }
316 |
317 | #[test]
318 | fn always_crop_p() {
319 | let image = DynamicImage::new(100, 110, image::ColorType::L8);
320 | assert_eq!(
321 | center_crop_size_and_offset(Variant::P, &image),
322 | (100, 100, 0, 5)
323 | );
324 | }
325 | }
326 |
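The cropping and padding geometry above is easiest to check with concrete numbers. Here is a minimal test-style sketch (hypothetical; it calls the private helpers of this module, so it would have to live alongside the tests above), whose expected values follow directly from the arithmetic in `center_crop_size_and_offset` and `remove_boundary_artifact`:

```rust
#[test]
fn wide_image_geometry() {
    // 4000x1000 is wider than 2:1, so even the Q variant center-crops a
    // square with side min(4000, 1000) = 1000, offset by
    // (4000 - 1000) / 2 = 1500 columns from the left edge.
    let image = DynamicImage::new(4000, 1000, image::ColorType::L8);
    assert_eq!(
        center_crop_size_and_offset(Variant::Q, &image),
        (1000, 1000, 1500, 0)
    );

    // For the same image, `remove_boundary_artifact` mean-pads the residual
    // into a canvas whose short side is 256 and whose aspect ratio matches
    // the input: (4000 / 1000) * 256 = 1024 columns, with the 256-wide
    // residual assigned at column offset (1024 - 256) / 2 = 384.
}
```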
--------------------------------------------------------------------------------
/rust/src/lib.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2024 Adobe
2 | // All Rights Reserved.
3 | //
4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | // accordance with the terms of the Adobe license agreement accompanying
6 | // it.
7 |
8 | //! # Trustmark
9 | //!
10 | //! An implementation of TrustMark watermarking for the Content Authenticity Initiative (CAI) in
11 | //! Rust, as described in:
12 | //!
13 | //! ---
14 | //!
15 | //! **TrustMark - Universal Watermarking for Arbitrary Resolution Images**
16 | //!
17 | //! <https://arxiv.org/abs/2311.18297>
18 | //!
19 | //! [Tu Bui]<sup>1</sup>, [Shruti Agarwal]<sup>2</sup>, [John Collomosse]<sup>1,2</sup>
20 | //!
21 | //! <sup>1</sup> DECaDE Centre for the Decentralized Digital Economy, University of Surrey, UK.\
22 | //! <sup>2</sup> Adobe Research, San Jose CA.
23 | //!
24 | //! ---
25 | //!
26 | //! This is a re-implementation of the [trustmark] Python library.
27 | //!
28 | //! [Tu Bui]: https://www.surrey.ac.uk/people/tu-bui
29 | //! [Shruti Agarwal]: https://research.adobe.com/person/shruti-agarwal/
30 | //! [John Collomosse]: https://www.collomosse.com/
31 | //! [trustmark]: https://pypi.org/project/trustmark/
32 | //!
33 | //! ## Example
34 | //!
35 | //! ```rust
36 | //! use trustmark::{Trustmark, Version, Variant};
37 | //!
38 | //! # fn main() {
39 | //! let tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
40 | //! let input = image::open("../images/ghost.png").unwrap();
41 | //! let output = tm.encode("0010101".to_owned(), input, 0.95);
42 | //! # }
43 | //! ```
44 | use std::path::Path;
45 |
46 | use image::{DynamicImage, GenericImageView as _};
47 | use ort::{GraphOptimizationLevel, Session};
48 |
49 | use self::{bits::Bits, image_processing::ModelImage};
50 |
51 | mod bits;
52 | mod image_processing;
53 | mod model;
54 |
55 | /// A loaded Trustmark model.
56 | pub struct Trustmark {
57 | encoder: Session,
58 | decoder: Session,
59 | version: Version,
60 | variant: Variant,
61 | }
62 |
63 | #[derive(Debug, thiserror::Error)]
64 | pub enum Error {
65 | #[error("watermark is corrupt or missing")]
66 | CorruptWatermark,
67 | #[error("onnx error: {0}")]
68 | Ort(#[from] ort::Error),
69 | #[error("image processing error: {0}")]
70 | ImageProcessing(#[from] image_processing::Error),
71 | #[error("bits processing error: {0}")]
72 | Bits(bits::Error),
73 | #[error("invalid model variant")]
74 | InvalidModelVariant,
75 | }
76 |
77 | impl From<bits::Error> for Error {
78 | fn from(value: bits::Error) -> Self {
79 | match value {
80 | bits::Error::CorruptWatermark => Error::CorruptWatermark,
81 | err => Error::Bits(err),
82 | }
83 | }
84 | }
85 |
86 | pub use bits::Version;
87 | pub use model::Variant;
88 |
89 | impl Trustmark {
90 | /// Load a Trustmark model.
91 | pub fn new<P: AsRef<Path>>(
92 | models: P,
93 | variant: Variant,
94 | version: Version,
95 | ) -> Result<Self, Error> {
96 | let encoder = Session::builder()?
97 | .with_optimization_level(GraphOptimizationLevel::Level3)?
98 | .with_intra_threads(8)?
99 | .commit_from_file(models.as_ref().join(variant.encoder_filename()))?;
100 | let decoder = Session::builder()?
101 | .with_optimization_level(GraphOptimizationLevel::Level3)?
102 | .with_intra_threads(8)?
103 | .commit_from_file(models.as_ref().join(variant.decoder_filename()))?;
104 | Ok(Self {
105 | encoder,
106 | decoder,
107 | version,
108 | variant,
109 | })
110 | }
111 |
112 | /// Encode a watermark into an image.
113 | ///
114 | /// `watermark` is the bitstring identifier to embed. `img` is the image to be watermarked.
115 | /// `strength` is a number between 0 and 1 controlling how strong the resulting watermark
116 | /// is; 0.95 is a typical strength.
117 | pub fn encode(
118 | &self,
119 | watermark: String,
120 | img: DynamicImage,
121 | strength: f32,
122 | ) -> Result<DynamicImage, Error> {
123 | let (original_width, original_height) = img.dimensions();
124 | let aspect_ratio = original_width as f32 / original_height as f32;
125 |
126 | // The image is always encoded at 256x256.
127 | let encode_size = 256;
128 |
129 | let input_img: ort::Value<ort::TensorValueType<f32>> =
130 | ModelImage(encode_size, self.variant, img.clone()).try_into()?;
131 | let bits: ort::Value<ort::TensorValueType<f32>> =
132 | Bits::apply_error_correction_and_schema(watermark, self.version)?.into();
133 | let outputs = self.encoder.run(ort::inputs![
134 | "onnx::Concat_0" => input_img,
135 | "onnx::Gemm_1" => bits,
136 | ]?)?;
137 | let output_img = outputs["image"].try_extract_tensor::<f32>()?.to_owned();
138 |
139 | // Need to calculate and apply the residual.
140 | let input_img: ort::Value<ort::TensorValueType<f32>> =
141 | ModelImage(encode_size, self.variant, img.clone()).try_into()?;
142 | let residual = (self.variant.strength_multiplier() * strength)
143 | * (output_img - input_img.try_extract_tensor::<f32>()?);
144 |
145 | // Residual should be small perturbations.
146 | let mut residual = residual.clamp(-0.2, 0.2);
147 | if (self.variant == Variant::Q && !(0.5..=2.0).contains(&aspect_ratio))
148 | || self.variant == Variant::P
149 | {
150 | residual = image_processing::remove_boundary_artifact(
151 | residual,
152 | (original_width as usize, original_height as usize),
153 | self.variant,
154 | );
155 | }
156 |
157 | let ModelImage(_, _, residual) = (encode_size, self.variant, residual).try_into()?;
158 |
159 | Ok(image_processing::apply_residual(img, residual))
160 | }
161 |
162 | /// Decode a watermark from an image.
163 | pub fn decode(&self, img: DynamicImage) -> Result<String, Error> {
164 | // P variant has a smaller decode size
165 | let decode_size = if self.variant == Variant::P { 224 } else { 256 };
166 |
167 | let img: ort::Value<ort::TensorValueType<f32>> =
168 | ModelImage(decode_size, self.variant, img).try_into()?;
169 | let outputs = self.decoder.run(ort::inputs![
170 | "image" => img,
171 | ]?)?;
172 | let watermark = outputs["output"].try_extract_tensor::<f32>()?.to_owned();
173 | let watermark: Bits = watermark.try_into()?;
174 | Ok(watermark.get_data())
175 | }
176 | }
177 |
178 | #[cfg(test)]
179 | mod tests {
180 | use super::*;
181 |
182 | #[test]
183 | fn loading_models() {
184 | Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
185 | }
186 |
187 | fn roundtrip(path: impl AsRef<Path>) {
188 | let tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
189 | let input = image::open(path.as_ref()).unwrap();
190 | let watermark = "1011011110011000111111000000011111011111011100000110110110111".to_owned();
191 | let encoded = tm.encode(watermark.clone(), input, 0.95).unwrap();
192 | encoded.to_rgba8().save("./test.png").unwrap();
193 | let input = image::open("./test.png").unwrap();
194 | let decoded = tm.decode(input).unwrap();
195 | assert_eq!(watermark, decoded);
196 | }
197 |
198 | #[test]
199 | fn roundtrip_ghost() {
200 | roundtrip("../images/ghost.png");
201 | }
202 |
203 | #[test]
204 | fn roundtrip_ufo() {
205 | roundtrip("../images/ufo_240.jpg");
206 | }
207 | }
208 |
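A standalone round-trip through the public API mirrors the `roundtrip` test above. A minimal sketch, assuming the ONNX models have been fetched into `./models` and that the example image path exists:

```rust
use trustmark::{Trustmark, Variant, Version};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load the Q-variant encoder/decoder pair with the Bch5 schema.
    let tm = Trustmark::new("./models", Variant::Q, Version::Bch5)?;

    // Embed a bitstring watermark at the typical strength of 0.95.
    let input = image::open("../images/ghost.png")?;
    let encoded = tm.encode("0010101".to_owned(), input, 0.95)?;

    // Recover the watermark payload from the watermarked image.
    let decoded = tm.decode(encoded)?;
    println!("recovered watermark: {decoded}");
    Ok(())
}
```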
--------------------------------------------------------------------------------
/rust/src/model.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Adobe
2 | // All Rights Reserved.
3 | //
4 | // NOTICE: Adobe permits you to use, modify, and distribute this file in
5 | // accordance with the terms of the Adobe license agreement accompanying
6 | // it.
7 |
8 | use std::{fmt::Display, str::FromStr};
9 |
10 | use crate::Error;
11 |
12 | /// The model variant to load.
13 | #[derive(Copy, Clone, Debug, PartialEq)]
14 | pub enum Variant {
15 | B,
16 | C,
17 | P,
18 | Q,
19 | }
20 |
21 | impl Variant {
22 | pub(super) fn encoder_filename(&self) -> String {
23 | let suffix = match self {
24 | Variant::B => "B",
25 | Variant::C => "C",
26 | Variant::P => "P",
27 | Variant::Q => "Q",
28 | };
29 |
30 | format!("encoder_{suffix}.onnx")
31 | }
32 |
33 | pub(super) fn decoder_filename(&self) -> String {
34 | let suffix = match self {
35 | Variant::B => "B",
36 | Variant::C => "C",
37 | Variant::P => "P",
38 | Variant::Q => "Q",
39 | };
40 |
41 | format!("decoder_{suffix}.onnx")
42 | }
43 |
44 | pub(super) fn strength_multiplier(&self) -> f32 {
45 | match self {
46 | Variant::P => 1.25,
47 | _ => 1.,
48 | }
49 | }
50 | }
51 |
52 | impl Display for Variant {
53 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 | let s = match self {
55 | Variant::B => "B",
56 | Variant::C => "C",
57 | Variant::P => "P",
58 | Variant::Q => "Q",
59 | };
60 |
61 | f.write_str(s)
62 | }
63 | }
64 |
65 | impl FromStr for Variant {
66 | type Err = Error;
67 |
68 | fn from_str(s: &str) -> Result<Self, Self::Err> {
69 | Ok(match s {
70 | "B" => Variant::B,
71 | "C" => Variant::C,
72 | "P" => Variant::P,
73 | "Q" => Variant::Q,
74 | _ => return Err(Error::InvalidModelVariant),
75 | })
76 | }
77 | }
78 |
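Because `Variant` implements both `FromStr` and `Display`, variant names round-trip through plain strings, which is presumably how the `trustmark-cli` crate maps a variant argument onto a model file. A minimal sketch:

```rust
use trustmark::Variant;

fn main() {
    // Parse a variant from its single-letter name; unknown names fail
    // with Error::InvalidModelVariant.
    let variant: Variant = "Q".parse().expect("known variant");
    assert_eq!(variant.to_string(), "Q");
    assert!("Z".parse::<Variant>().is_err());
}
```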
--------------------------------------------------------------------------------