openff-models's Introduction

openff-models

Helper classes for Pydantic compatibility in the OpenFF stack

Getting started

import pprint
import json

from openff.models.models import DefaultModel
from openff.models.types import ArrayQuantity, FloatQuantity
from openff.units import unit


class Atom(DefaultModel):
    mass: FloatQuantity["atomic_mass_constant"]
    charge: FloatQuantity["elementary_charge"]
    some_array: ArrayQuantity["nanometer"]


atom = Atom(
    mass=12.011 * unit.atomic_mass_constant,
    charge=0.0 * unit.elementary_charge,
    some_array=unit.Quantity([4, -1, 0], unit.nanometer),
)

print(atom.dict())
# {'mass': <Quantity(12.011, 'atomic_mass_constant')>, 'charge': <Quantity(0.0, 'elementary_charge')>, 'some_array': <Quantity([ 4 -1  0], 'nanometer')>}

# Note that unit-bearing fields use custom serialization into a dict with separate key-val pairs for
# the unit (as a string) and unitless quantities (in whatever shape the data is)
print(atom.json())
# {"mass": "{\"val\": 12.011, \"unit\": \"atomic_mass_constant\"}", "charge": "{\"val\": 0.0, \"unit\": \"elementary_charge\"}", "some_array": "{\"val\": [4, -1, 0], \"unit\": \"nanometer\"}"}

# The same thing, just more human-readable
pprint.pprint(json.loads(atom.json()))
# {'charge': '{"val": 0.0, "unit": "elementary_charge"}',
#  'mass': '{"val": 12.011, "unit": "atomic_mass_constant"}',
#  'some_array': '{"val": [4, -1, 0], "unit": "nanometer"}'}

# Can also roundtrip through these representations
assert Atom(**atom.dict()).charge.m == 0.0
assert Atom.parse_raw(atom.json()).charge.m == 0.0
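
Because each unit-bearing field is written as a JSON string that itself contains JSON, a consumer that does not have openff-models installed needs to decode twice. A minimal sketch using only the standard library, with the payload copied from the output above (the variable names are illustrative only):

```python
import json

# Payload copied verbatim from the atom.json() output shown above.
payload = r'{"mass": "{\"val\": 12.011, \"unit\": \"atomic_mass_constant\"}", "charge": "{\"val\": 0.0, \"unit\": \"elementary_charge\"}", "some_array": "{\"val\": [4, -1, 0], \"unit\": \"nanometer\"}"}'

# First pass: the outer object maps field names to JSON-encoded strings.
outer = json.loads(payload)
assert isinstance(outer["mass"], str)

# Second pass: each string holds a {"val": ..., "unit": ...} dict.
mass = json.loads(outer["mass"])
some_array = json.loads(outer["some_array"])

print(mass["val"], mass["unit"])
# 12.011 atomic_mass_constant
print(some_array["val"], some_array["unit"])
# [4, -1, 0] nanometer
```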

Currently, models can also be defined with a simple unit.Quantity annotation. This keeps the serialization functionality but does not pick up the validation features of the custom types, i.e. dimensionality validation.

import json

from openff.units import unit
from openff.models.models import DefaultModel


class Atom(DefaultModel):
    mass: unit.Quantity = unit.Quantity(0.0, unit.amu)

json.loads(Atom(mass=12.011 * unit.atomic_mass_constant).json())
# {'mass': '{"val": 12.011, "unit": "atomic_mass_constant"}'}

# This model does not have instructions to keep masses in mass units
json.loads(Atom(mass=12.011 * unit.nanometer).json())
# {'mass': '{"val": 12.011, "unit": "nanometer"}'}

Copyright

Copyright (c) 2022, Open Force Field Initiative

Acknowledgements

Project based on the Computational Molecular Science Python Cookiecutter version 1.6.

openff-models's Issues

Support arrays with type checking

Thanks for starting this up @mattwthompson, looking forward to just using openff.models in everything!

It would be nice to be able to specify array typing too (and possibly shape checking). I implemented some array dtyping once, but never got round to doing shapes.

from typing import Any

import numpy as np
from openff.units import unit  # needed at class-definition time for unit.Quantity


class ArrayMeta(type):
    def __getitem__(cls, T):
        return type("Array", (Array,), {"__dtype__": T})


class Array(np.ndarray, unit.Quantity, metaclass=ArrayMeta):
    """A typeable numpy array"""

    @classmethod
    def __get_validators__(cls):
        yield cls.validate_type

    @classmethod
    def validate_type(cls, val):
        from openff.units import unit
        from openff.units.units import Unit
        
        dtype = getattr(cls, "__dtype__", Any)
        if dtype is Any:
            dtype = None
            
        if isinstance(dtype, Unit):
            # assign units
            val = unit.Quantity(val, dtype)
            # coerce into np.ndarray
            val = unit.Quantity.from_list(val)
            return val
        return np.asanyarray(val, dtype=dtype)

In practice, this looks like:

In:

from pydantic import BaseModel


class Model(BaseModel):
    a: Array[unit.kelvin]
    b: Array[float]
    c: Array[int]
    
    class Config:
        arbitrary_types_allowed = True
        validate_assignment = True
    

In:

int_array = np.arange(3).astype(int)
x = Model(a=int_array, b=int_array, c=int_array)
print(x.a)
print(x.b)
print(x.c)

Out:

[0.0 1.0 2.0] kelvin
[0. 1. 2.]
[0 1 2]

And an error:

In:

x.a = 3 * unit.kelvin

Out:

ValidationError                           Traceback (most recent call last)
Input In [130], in <cell line: 1>()
----> 1 x.a = 3 * unit.kelvin

File ~/anaconda3/envs/gnn-charge-models-test/lib/python3.9/site-packages/pydantic/main.py:380, in pydantic.main.BaseModel.__setattr__()

ValidationError: 1 validation error for Model
a
  object of type 'int' has no len() (type=type_error)

Currently it strips units if there are any.

x.b = 3 * unit.kelvin
/var/folders/rv/j6lbln6j0kvb5svxj8wflc400000gn/T/ipykernel_18158/1139253006.py:32: UnitStrippedWarning: The unit of the quantity is stripped when downcasting to ndarray.
  return np.asanyarray(val, dtype=dtype)

And type checking still runs:

x.a = int_array * unit.m
ValidationError: 1 validation error for Model
a
  Cannot convert from 'meter' ([length]) to 'kelvin' ([temperature]) (type=type_error.dimensionality; units1=meter; units2=kelvin; dim1=[length]; dim2=[temperature]; extra_msg=)

`ArrayQuantity.validate_type` chokes on `list[openmm.unit.Quantity[openmm.Vec3]]`

Context: openforcefield/openff-interchange#879

from openff.models.types import ArrayQuantity, MissingUnitError
from openff.units.openmm import ensure_quantity
from openff.toolkit import Molecule, ForceField, Quantity, unit
import numpy
import openmm.unit
import openmm
from openff.interchange import Interchange

topology = Molecule.from_smiles("CCO").to_topology()
topology.box_vectors = Quantity(4 * numpy.eye(3), unit.nanometer)

# this is a 3-length list of openmm.unit.Quantity objects that themselves wrap openmm.Vec3 objects
openmm_box = ForceField("openff-2.1.0.offxml").create_openmm_system(topology).getDefaultPeriodicBoxVectors()

assert isinstance(openmm_box, list)
assert isinstance(openmm_box[-1], openmm.unit.Quantity)
assert isinstance(openmm_box[-1]._value, openmm.Vec3)

try:
    Interchange.validate_box(openmm_box)
except MissingUnitError as error:
    error1 = error

try:
    ArrayQuantity.validate_type(openmm_box)
except MissingUnitError as error:
    error2 = error

assert str(error1) == str(error2)

Support non-FloatQuantity and non-ArrayQuantity types

I expected this to work:

from openff.models.models import DefaultModel
from openff.models.types import ArrayQuantity, FloatQuantity
from openff.units import unit


class Atom(DefaultModel):
    mass: FloatQuantity["atomic_mass_constant"]
    charge: FloatQuantity["elementary_charge"]
    some_array: ArrayQuantity["nanometer"]
    baz: int


atom = Atom(
    mass=12.011 * unit.atomic_mass_constant,
    charge=0.0 * unit.elementary_charge,
    some_array=unit.Quantity([4, -1, 0], unit.nanometer),
    baz = 2,
)

Atom.parse_raw(atom.json())

But it fails on the round trip since it can't parse baz properly:

{'baz': 2,
 'charge': '{"val": 0.0, "unit": "elementary_charge"}',
 'mass': '{"val": 12.011, "unit": "atomic_mass_constant"}',
 'some_array': '{"val": [4, -1, 0], "unit": "nanometer"}'}
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/miniconda3/envs/gufe/lib/python3.10/site-packages/pydantic/main.py:534, in pydantic.main.BaseModel.parse_raw()

File ~/miniconda3/envs/gufe/lib/python3.10/site-packages/pydantic/parse.py:37, in pydantic.parse.load_str_bytes()

File ~/miniconda3/envs/gufe/lib/python3.10/site-packages/openff/models/types.py:135, in json_loader(data)
    132 try:
    133     # Directly look for an encoded FloatQuantity/ArrayQuantity,
    134     # which is itself a dict
--> 135     v = json.loads(val)
    136 except json.JSONDecodeError:
    137     # Handles some cases of the val being a primitive type

File ~/miniconda3/envs/gufe/lib/python3.10/json/__init__.py:339, in loads(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)
    338 if not isinstance(s, (bytes, bytearray)):
--> 339     raise TypeError(f'the JSON object must be str, bytes or bytearray, '
    340                     f'not {s.__class__.__name__}')
    341 s = s.decode(detect_encoding(s), 'surrogatepass')

TypeError: the JSON object must be str, bytes or bytearray, not int

During handling of the above exception, another exception occurred:

ValidationError                           Traceback (most recent call last)
Input In [4], in <cell line: 20>()
     10     baz: int
     13 atom = Atom(
     14     mass=12.011 * unit.atomic_mass_constant,
     15     charge=0.0 * unit.elementary_charge,
     16     some_array=unit.Quantity([4, -1, 0], unit.nanometer),
     17     baz = 2,
     18 )
---> 20 Atom.parse_raw(atom.json())

File ~/miniconda3/envs/gufe/lib/python3.10/site-packages/pydantic/main.py:543, in pydantic.main.BaseModel.parse_raw()

ValidationError: 1 validation error for Atom
__root__
  the JSON object must be str, bytes or bytearray, not int (type=type_error)

This bit https://github.com/mattwthompson/openff-models/blob/61260ec1d2746a0431671498dfe1540b006354db/openff/models/types.py#L137
doesn't actually seem to handle the case where the type isn't a float or array quantity.
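
The traceback suggests that only json.JSONDecodeError is caught, while a non-string value such as an int makes json.loads raise TypeError. One possible shape for a more tolerant loader is sketched below. It is not the library's actual fix: `tolerant_json_loader` is a hypothetical name, and it leaves each quantity as a plain dict rather than building a unit.Quantity.

```python
import json


def tolerant_json_loader(data: str) -> dict:
    """Decode a model's JSON, unpacking only the fields that encode a quantity.

    Hypothetical sketch: a real loader would turn each val/unit pair into a
    unit.Quantity instead of leaving it as a dict.
    """
    decoded = {}
    for key, val in json.loads(data).items():
        if not isinstance(val, str):
            # ints, floats, bools, None, lists and dicts arrive already
            # decoded; passing them to json.loads would raise TypeError.
            decoded[key] = val
            continue
        try:
            inner = json.loads(val)
        except json.JSONDecodeError:
            # A plain string field that is not itself JSON.
            decoded[key] = val
            continue
        is_quantity = isinstance(inner, dict) and {"val", "unit"} <= inner.keys()
        decoded[key] = inner if is_quantity else val
    return decoded


raw = json.dumps(
    {
        "baz": 2,
        "charge": json.dumps({"val": 0.0, "unit": "elementary_charge"}),
    }
)
result = tolerant_json_loader(raw)
print(result["baz"], result["charge"]["unit"])
```

Checking the type before decoding, rather than widening the except clause, keeps genuine decoding errors visible.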

`__version__` not set

git checkout 0.0.0 && python -m pip install -e . && python -c "from openff.models import __version__; print(__version__)"
Previous HEAD position was 90588ac Initial commit after CMS Cookiecutter creation, version 1.6
HEAD is now at 97f7327 Merge pull request #10 from mattwthompson/dependabot/github_actions/actions/checkout-3
Obtaining file:///Users/mattthompson/software/openff-models
  Preparing metadata (setup.py) ... done
Installing collected packages: openff-models
  Running setup.py develop for openff-models
Successfully installed openff-models-0.0.0
Traceback (most recent call last):
  File "<string>", line 1, in <module>
ImportError: cannot import name '__version__' from 'openff.models' (/Users/mattthompson/software/openff-models/openff/models/__init__.py)

Spurious warning in some cases of processing OpenMM objects

With 0.1.2:

In [9]: system = openmm.XmlSerializer.deserialize(open("system.xml").read())

In [10]: ArrayQuantity.validate_type(system.getDefaultPeriodicBoxVectors())
/Users/mattthompson/micromamba/envs/openff-interchange-dev/lib/python3.11/site-packages/pint/compat.py:60: UnitStrippedWarning: The unit of the quantity is stripped when downcasting to ndarray.
  return np.asarray(value)
Out[10]:
array([[2.5, 0. , 0. ],
       [0. , 2.5, 0. ],
       [0. , 0. , 2.5]]) <Unit('nanometer')>

Why not use openff.units units in the BaseModel?

There's probably a good technical reason, but is there a reason we can't hack either (this package's) BaseModel or openff-units to allow this to work:

In [12]: class Thing(BaseModel):
    ...:     temp: units.unit.kelvin
    ...:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/miniconda3/envs/openfe/lib/python3.10/site-packages/pydantic/validators.py:751, in pydantic.validators.find_validators()

TypeError: issubclass() arg 1 must be a class

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In [12], line 1
----> 1 class Thing(BaseModel):
      2     temp: units.unit.kelvin

File ~/miniconda3/envs/openfe/lib/python3.10/site-packages/pydantic/main.py:198, in pydantic.main.ModelMetaclass.__new__()

File ~/miniconda3/envs/openfe/lib/python3.10/site-packages/pydantic/fields.py:506, in pydantic.fields.ModelField.infer()

File ~/miniconda3/envs/openfe/lib/python3.10/site-packages/pydantic/fields.py:436, in pydantic.fields.ModelField.__init__()

File ~/miniconda3/envs/openfe/lib/python3.10/site-packages/pydantic/fields.py:557, in pydantic.fields.ModelField.prepare()

File ~/miniconda3/envs/openfe/lib/python3.10/site-packages/pydantic/fields.py:831, in pydantic.fields.ModelField.populate_validators()

File ~/miniconda3/envs/openfe/lib/python3.10/site-packages/pydantic/validators.py:760, in find_validators()

RuntimeError: error checking inheritance of <Unit('kelvin')> (type: Unit)

Validation issue?

While debugging an issue, I noticed that validation does not seem to happen when a user passes in a value without units; instead, the value is coerced to the annotated unit.

from openff.models.models import DefaultModel
from openff.models.types import FloatQuantity
from openff.units import unit

class Atom(DefaultModel):
    mass: FloatQuantity["atomic_mass_constant"] 
    
# This should error but works
atom = Atom(mass=12)
print(atom)
# This works as expected
atom_units = Atom(mass=12 * unit.amu)
print(atom_units)
# This fails as expected: wrong units are caught, but missing units are not
atom_units_wrong = Atom(mass=12 * unit.kelvin)
print(atom_units_wrong)

I expected to hit this code branch https://github.com/mattwthompson/openff-models/blob/main/openff/models/types.py#L39-L43

Instead we are hitting this bit https://github.com/mattwthompson/openff-models/blob/main/openff/models/types.py#L64-L65 which just adds the units to the value.
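
To make the two branches concrete, here is a sketch of a validator with a strictness switch. Everything in it is a stand-in written for illustration: `Quantity` is a tiny dataclass rather than unit.Quantity, `MissingUnitError` only mirrors the name of the error in openff.models.types, and `validate_quantity` with its `strict` flag is hypothetical.

```python
from dataclasses import dataclass


class MissingUnitError(ValueError):
    """Stand-in for the error of the same name in openff.models.types."""


@dataclass(frozen=True)
class Quantity:
    """Minimal stand-in for unit.Quantity: a magnitude plus a unit name."""

    magnitude: float
    unit: str


def validate_quantity(val, expected_unit: str, strict: bool = True) -> Quantity:
    """Accept unit-bearing values; reject or coerce bare numbers."""
    if isinstance(val, Quantity):
        # A real validator would also check dimensionality at this point.
        return val
    if strict:
        # The behaviour the issue expected: bare numbers are an error.
        raise MissingUnitError(f"Value {val!r} needs to be tagged with a unit")
    # The behaviour the issue observed: silently attach the annotated unit.
    return Quantity(float(val), expected_unit)


print(validate_quantity(12, "atomic_mass_constant", strict=False))
try:
    validate_quantity(12, "atomic_mass_constant")
except MissingUnitError as error:
    print("rejected:", error)
```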

Support dimensionality checking

It would be nice to support something like:

class Atom(DefaultModel):
    mass: FloatQuantity["mass"]

And then we would accept any unit that has dimensionality "mass".
The same would go for "length", so we would accept "nm", "angstrom", etc.

something like hgrecco/pint#1166 (comment)
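
A toy sketch of the requested semantics follows. It assumes an annotation may name either a dimension or a concrete unit. The lookup table stands in for a real unit registry, and `accepts` is a hypothetical helper, not an openff-models or pint function.

```python
# Toy lookup standing in for a unit registry; only a handful of units are
# listed, and the dimension labels are plain strings.
UNIT_DIMENSIONS = {
    "nanometer": "length",
    "angstrom": "length",
    "atomic_mass_constant": "mass",
    "kilogram": "mass",
    "kelvin": "temperature",
}
KNOWN_DIMENSIONS = set(UNIT_DIMENSIONS.values())


def accepts(annotation: str, unit_name: str) -> bool:
    """Return True if a value in `unit_name` satisfies `annotation`.

    `annotation` may be a dimension ("mass") or a concrete unit
    ("nanometer"); either way only the dimensionality is compared.
    """
    if annotation in KNOWN_DIMENSIONS:
        wanted = annotation
    else:
        wanted = UNIT_DIMENSIONS[annotation]
    return UNIT_DIMENSIONS[unit_name] == wanted


print(accepts("mass", "kilogram"))
print(accepts("length", "angstrom"))
print(accepts("mass", "nanometer"))
```

With a real registry, the table lookup would be replaced by a dimensionality comparison on the parsed units.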

compound FloatQuantity

I'm trying to do:

class T(models.BaseModel):
    foo: types.FloatQuantity['unit.boltzmann_constant * unit.kelvin']

T(foo=2.5 * unit.boltzmann_constant * unit.kelvin)

and I get:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[37], line 1
----> 1 T2(foo=2.5 * unit.boltzmann_constant * unit.kelvin)

File ~/miniconda3/envs/openfe_7/lib/python3.10/site-packages/pydantic/main.py:339, in pydantic.main.BaseModel.__init__()

File ~/miniconda3/envs/openfe_7/lib/python3.10/site-packages/pydantic/main.py:1076, in pydantic.main.validate_model()

File ~/miniconda3/envs/openfe_7/lib/python3.10/site-packages/pydantic/fields.py:884, in pydantic.fields.ModelField.validate()

File ~/miniconda3/envs/openfe_7/lib/python3.10/site-packages/pydantic/fields.py:1101, in pydantic.fields.ModelField._validate_singleton()

File ~/miniconda3/envs/openfe_7/lib/python3.10/site-packages/pydantic/fields.py:1151, in pydantic.fields.ModelField._apply_validators()

File ~/miniconda3/envs/openfe_7/lib/python3.10/site-packages/pydantic/class_validators.py:337, in pydantic.class_validators._generic_validator_basic.lambda13()

File ~/miniconda3/envs/openfe_7/lib/python3.10/site-packages/openff/models/types.py:54, in FloatQuantity.validate_type(cls, val)
     50         raise UnitValidationError(
     51             f"Could not validate data of type {type(val)}"
     52         )
     53 else:
---> 54     unit_ = unit(unit_)
     55     if isinstance(val, unit.Quantity):
     56         # some custom behavior could go here
     57         assert unit_.dimensionality == val.dimensionality

File ~/miniconda3/envs/openfe_7/lib/python3.10/site-packages/pint/facets/plain/registry.py:1252, in PlainRegistry.parse_expression(self, input_string, case_sensitive, use_decimal, **values)
   1250 for p in self.preprocessors:
   1251     input_string = p(input_string)
-> 1252 input_string = string_preprocessor(input_string)
   1253 gen = tokenizer(input_string)
   1255 return build_eval_tree(gen).evaluate(
   1256     lambda x: self._eval_token(x, case_sensitive=case_sensitive, **values)
   1257 )

File ~/miniconda3/envs/openfe_7/lib/python3.10/site-packages/pint/util.py:780, in string_preprocessor(input_string)
    779 def string_preprocessor(input_string: str) -> str:
--> 780     input_string = input_string.replace(",", "")
    781     input_string = input_string.replace(" per ", "/")
    783     for a, b in _subs_re:

AttributeError: 'Unit' object has no attribute 'replace'

Is this something that should work (eventually if I fixed it) or am I using FloatQuantity wrong here?

Support some convenience types

I'm not exactly sure what these are called, but it would be useful to have:

PositiveFloatQuantity
NegativeFloatQuantity
OptionalFloatQuantity

It would be nice to use OptionalFloatQuantity to denote a field that is optional (as of now, I wasn't able to hack together something that would work). Something like PositiveFloatQuantity would be useful when defining a model where only positive values make sense (like pydantic has with PositiveFloat).
