Converting a protein sequence FeatureDict to a TensorDict

The core conversion is the single statement tensor_dict = {k: tf.constant(v) for k, v in np_example.items() if k in features_metadata}. Around it, the code adds selection of feature names plus checks that each feature has the expected dimensions and element count.
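As a minimal sketch of what that one statement does, the snippet below (with invented feature names and array values, not taken from a real FeatureDict) drops keys that are absent from the metadata and wraps the remaining NumPy arrays in tf.constant. The full script below adds the metadata table and shape checks around this idea.

import numpy as np
import tensorflow as tf

# Toy inputs (hypothetical): one key that is in the metadata, one that is not.
toy_np_example = {
    "seq_length": np.array([3], dtype=np.int64),
    "unlisted_feature": np.zeros(5),
}
toy_features_metadata = {"seq_length": (tf.int64, [1])}

# Keys absent from the metadata are dropped; NumPy arrays become tf.Tensor.
toy_tensor_dict = {k: tf.constant(v) for k, v in toy_np_example.items()
                   if k in toy_features_metadata}
print(toy_tensor_dict.keys())  # dict_keys(['seq_length'])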

from typing import Dict, Tuple, Sequence, Union, Mapping, Optional
#import tensorflow.compat.v1 as tf
import tensorflow as tf
import numpy as np
import pickle

# Type aliases.
FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]

#FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, tf.Tensor]

NUM_RES = 'num residues placeholder'
NUM_TEMPLATES = 'num templates placeholder'
NUM_SEQ = "length msa placeholder"

atom_type_num = 37

FEATURES = {
    #### Static features of a protein sequence ####
    "aatype": (tf.float32, [NUM_RES, 21]),
    "between_segment_residues": (tf.int64, [NUM_RES, 1]),
    "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]),
    "domain_name": (tf.string, [1]),
    "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]),
    "num_alignments": (tf.int64, [NUM_RES, 1]),
    "residue_index": (tf.int64, [NUM_RES, 1]),
    "seq_length": (tf.int64, [NUM_RES, 1]),
    "sequence": (tf.string, [1]),
    "all_atom_positions": (tf.float32,
                           [NUM_RES, atom_type_num, 3]),
    "all_atom_mask": (tf.int64, [NUM_RES, atom_type_num]),
    "resolution": (tf.float32, [1]),
    "template_domain_names": (tf.string, [NUM_TEMPLATES]),
    "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]),
    "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]),
    "template_all_atom_positions": (tf.float32, [
        NUM_TEMPLATES, NUM_RES, atom_type_num, 3
    ]),
    "template_all_atom_masks": (tf.float32, [
        NUM_TEMPLATES, NUM_RES, atom_type_num, 1
    ]),
}
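
# Reading the table (illustrative check): each entry maps a feature name to a
# dtype plus a shape whose string entries (NUM_RES, NUM_SEQ, NUM_TEMPLATES) are
# placeholders resolved later from the actual protein.
assert FEATURES["msa"] == (tf.int64, [NUM_SEQ, NUM_RES, 1])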


def _make_features_metadata(
    feature_names: Sequence[str]) -> FeaturesMetadata:
  """Makes a feature name to type and shape mapping from a list of names."""
  # Make sure these features are always read.
  required_features = ["aatype", "sequence", "seq_length"]
  feature_names = list(set(feature_names) | set(required_features))
  features_metadata = {name: FEATURES[name] for name in feature_names}
  return features_metadata
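
# Usage sketch (hypothetical request): the required features are merged in even
# when only "msa" is asked for, so the metadata also covers aatype, sequence
# and seq_length.
_example_metadata = _make_features_metadata(["msa"])
assert set(_example_metadata) == {"msa", "aatype", "sequence", "seq_length"}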


def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
    ) -> TensorDict:
  """Creates dict of tensors from a dict of NumPy arrays.

  Args:
    np_example: A dict of NumPy feature arrays.
    features: A list of strings of feature names to be returned in the dataset.

  Returns:
    A dictionary of features mapping feature names to features. Only the given
    features are returned, all other ones are filtered out.
  """
  features_metadata = _make_features_metadata(features)
  print(f"features_metadata:{features_metadata}")

  tensor_dict = {k: tf.constant(v) for k, v in np_example.items()
                 if k in features_metadata}
  #print(f"tensor_dict:{tensor_dict}")

  # Ensures shapes are as expected. Needed for setting size of empty features
  # e.g. when no template hits were found.
  tensor_dict = parse_reshape_logic(tensor_dict, features_metadata)
  return tensor_dict


def protein_features_shape(feature_name: str,
                           num_residues: int,
                           msa_length: int,
                           num_templates: Optional[int] = None,
                           features: Optional[FeaturesMetadata] = None):
  """Get the shape for the given feature name.

  This is near identical to _get_tf_shape_no_placeholders() but with 2
  differences:
  * This method does not calculate a single placeholder from the total number of
    elements (eg given <NUM_RES, 3> and size := 12, this won't deduce NUM_RES
    must be 4)
  * This method will work with tensors

  Args:
    feature_name: String identifier for the feature. If the feature name ends
      with "_unnormalized", this suffix is stripped off.
    num_residues: The number of residues in the current domain - some elements
      of the shape can be dynamic and will be replaced by this value.
    msa_length: The number of sequences in the multiple sequence alignment, some
      elements of the shape can be dynamic and will be replaced by this value.
      If the number of alignments is unknown / not read, please pass None for
      msa_length.
    num_templates (optional): The number of templates in this tfexample.
    features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.

  Returns:
    List of ints representing the tensor size.

  Raises:
    ValueError: If a feature is requested but no concrete placeholder value is
        given.
  """
  features = features or FEATURES
  if feature_name.endswith("_unnormalized"):
    feature_name = feature_name[:-13]
  
  # `features` is a FeaturesMetadata mapping:
  # FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]
  unused_dtype, raw_sizes = features[feature_name]

  #print(f"feature_name:{feature_name}")
  #print(f"features[feature_name]:{features[feature_name]}")
  #print(f"unused_dtype:{unused_dtype}")
  #print(f"raw_sizes:{raw_sizes}")

  replacements = {NUM_RES: num_residues,
                  NUM_SEQ: msa_length}

  if num_templates is not None:
    replacements[NUM_TEMPLATES] = num_templates

  # dict.get(dim, dim): integer dims pass through, placeholder strings are replaced.
  sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
  for dimension in sizes:
    if isinstance(dimension, str):
      raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
          feature_name, raw_sizes, replacements))
  return sizes
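
# Illustrative check (the numbers are made up): with 4 residues and 10 MSA rows,
# the placeholders in FEATURES["msa"] resolve to concrete sizes. Asking for a
# template feature without num_templates would leave a string in the shape and
# raise the ValueError above.
assert protein_features_shape("msa", num_residues=4, msa_length=10) == [10, 4, 1]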


def parse_reshape_logic(
    parsed_features: TensorDict,
    features: FeaturesMetadata,
    key: Optional[str] = None) -> TensorDict:
  """Transforms parsed serial features to the correct shape."""
  # Find out what is the number of sequences and the number of alignments.
  num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32)

  if "num_alignments" in parsed_features:
    num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32)
  else:
    num_msa = 0

  if "template_domain_names" in parsed_features:
    num_templates = tf.cast(
        tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32)
  else:
    num_templates = 0

  if key is not None and "key" in features:
    parsed_features["key"] = [key]  # Expand dims from () to (1,).

  # Reshape the tensors according to the sequence length and num alignments.
  for k, v in parsed_features.items():
    new_shape = protein_features_shape(
        feature_name=k,
        num_residues=num_residues,
        msa_length=num_msa,
        num_templates=num_templates,
        features=features)
    
    #print(f"new_shape:{new_shape}")
    
    new_shape_size = tf.constant(1, dtype=tf.int32)
    
    
    for dim in new_shape:
      new_shape_size *= tf.cast(dim, tf.int32)

    #print(f"new_shape_size:{new_shape_size}")
    #print(f"original_shape_size:{ tf.size(v)}")
    
    # 断言函数,用于检查两个张量是否相等。不相等引发异常
    assert_equal = tf.assert_equal(
        tf.size(v), new_shape_size,
        name="assert_%s_shape_correct" % k,
        message="The size of feature %s (%s) could not be reshaped "
        "into %s" % (k, tf.size(v), new_shape))
    
    if "template" not in k:
      # Make sure the feature we are reshaping is not empty.
      assert_non_empty = tf.assert_greater(
          tf.size(v), 0, name="assert_%s_non_empty" % k,
          message="The feature %s is not set in the tf.Example. Either do not "
          "request the feature or use a tf.Example that has the "
          "feature set." % k)
      with tf.control_dependencies([assert_non_empty, assert_equal]):
        parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
    else:
      with tf.control_dependencies([assert_equal]):
        parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)

  return parsed_features


def _first(tensor: tf.Tensor) -> tf.Tensor:
  """Returns the 1st element - the input can be a tensor or a scalar."""
  return tf.reshape(tensor, shape=(-1,))[0]  # Flatten to 1-D, then take the first element.
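
# Toy end-to-end check of parse_reshape_logic (all names and values invented for
# the sketch): a flat "aatype" tensor with 2 * 21 elements is reshaped to [2, 21]
# once seq_length reports 2 residues; template features are simply not requested
# here.
_toy_metadata = _make_features_metadata(["aatype"])
_toy_parsed = {
    "aatype": tf.zeros([2 * 21], dtype=tf.float32),
    "sequence": tf.constant(["MK"]),
    "seq_length": tf.constant([2, 2], dtype=tf.int64),
}
_toy_parsed = parse_reshape_logic(_toy_parsed, _toy_metadata)
assert _toy_parsed["aatype"].shape == (2, 21)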


## Load the pickled list of FeatureDicts.
with open("HBB_features_lst.pkl", 'rb') as f:
  HBB_features_lst = pickle.load(f)

Human_HBB_feature_dict = HBB_features_lst[0]

print(Human_HBB_feature_dict.keys())

#print(Human_HBB_feature_dict['num_alignments'])
  
features = FEATURES.keys()

#for key in Human_HBB_feature_dict.keys():
#    if key not in features:
#        print(key)

#print(features)

Human_HBB_tensor_dict = np_to_tensor_dict(Human_HBB_feature_dict, features=features)

print(Human_HBB_tensor_dict.keys())
#print(Human_HBB_tensor_dict)

#print(Human_HBB_tensor_dict["template_domain_names"])
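
# To sanity-check the conversion, list the dtype and shape of every converted
# tensor (the exact shapes depend on what is stored in HBB_features_lst.pkl).
for name, tensor in Human_HBB_tensor_dict.items():
  print(name, tensor.dtype, tensor.shape)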
