Source code for EMToolKit.schemas.basic_schemas

import os, sys, time, pickle, resource, copy, numpy as np
# EMToolKit/schemas/basic_schemas.py
from . import operators, element_functions, element_grid, coord_grid, element_source_responses
# from .schemas import operators, element_functions, element_grid, coord_grid, element_source_responses
from EMToolKit.schemas.operators import multi_instrument_linear_operator, sparse_d2_partial_matrix
from EMToolKit.schemas.operators import reg_operator_postfac_wrapper, single_instrument_linear_operator_separable
from EMToolKit.schemas.basic_transforms import generic_transform, trivialframe, fits_transform
from EMToolKit.schemas.element_functions import (nd_voigt_psf, bin_function, get_2d_cov, get_3d_cov,
                               nd_gaussian_psf, nd_powgaussian_psf, spike_function,
                               flattop_guassian_psf, spice_spectrograph_psf)
from EMToolKit.schemas.element_grid import detector_grid, source_grid
from EMToolKit.schemas.coord_grid import coord_grid
from EMToolKit.schemas.element_source_responses import element_source_responses as esr
from sunpy.map import Map

# A detector schema needs to contain the information needed to map from a given source
# (or at least a source as defined by basic_source) onto its own detector numbers,
# as well as the code to compute the mapping.
class basic_detector(object):
    """Detector schema: maps a source model onto this detector's pixels.

    A detector schema holds the information needed to map from a given source
    (or at least a source as defined by ``basic_source``) onto its own detector
    numbers, as well as the code to compute the mapping.

    Parameters
    ----------
    meta : mapping
        Detector metadata. Keys read here: ``'SHAPE'``, ``'LOGT'``, ``'TRESP'``
        (required), ``'TRANSFORM'``, ``'PSFANGLE'``, ``'PSFSZ1'``, ``'PSFSZ2'``
        (optional). ``fwdop`` additionally reads ``'EXPTIME'``.
    """

    def __init__(self, meta):
        self.meta = meta
        self.wcs = None  # Not yet implemented
        self.ndim = 2
        # We're not very consistent here about how many axes the source has.
        # In some places we use more general code. However, it elsewhere
        # assumes we have a separable plane-of-sky+temperature response style
        # of detector, like AIA or XRT.
        self.transform = meta.get('TRANSFORM', fits_transform(self.meta))
        self.shape = meta['SHAPE']

        # Pixel pitch along each detector axis, measured as the coordinate
        # distance covered by a unit index step along that axis. The second
        # component must step along the second index (0,1); stepping along
        # (1,0) twice would report the first axis' pitch for both axes.
        pixel00 = self.transform.index2coord(0, 0)
        step_axis1 = self.transform.index2coord(1, 0) - pixel00
        step_axis2 = self.transform.index2coord(0, 1) - pixel00
        self.scale = np.array([np.sum(step_axis1**2)**0.5,
                               np.sum(step_axis2**2)**0.5])

        self.origin = self.transform.index2coord(0.0, 0.0)
        self.psfangle = meta.get('PSFANGLE', 0.0)
        # PSFSZ1/PSFSZ2 are multiplied by the pixel scale, so the stored
        # sigmas are in the transform's coordinate units.
        self.psfsigmas = np.array([meta.get('PSFSZ1', 0.5),
                                   meta.get('PSFSZ2', 0.5)])*self.scale
        self.psfcov = get_2d_cov(self.psfsigmas, self.psfangle)
        self.ipsfcov = np.linalg.inv(self.psfcov)
        # This assumes that the transform is affine:
        self.fwdtransform = np.matmul(self.transform.tfmat, np.diag(self.scale))
        self.frame = trivialframe(np.arange(1, self.ndim+1).astype(str))
        self.coords = coord_grid(self.shape, self.origin, self.fwdtransform, self.frame)
        self.grid = detector_grid(self.coords, [self.ipsfcov], nd_gaussian_psf)
        self.logt, self.tresp = meta['LOGT'], meta['TRESP']
        # NOTE(review): self.transform is rebound here. From now on it is the
        # object handed to the source-response builder in fwdop; when
        # 'TRANSFORM' is absent it is generic_transform itself (not called).
        # Confirm that reusing the 'TRANSFORM' key for both roles is intended.
        self.transform = meta.get('TRANSFORM', generic_transform)
        # Cache of {'OPERATOR': ..., 'SOURCE': ...} entries filled by fwdop.
        self.fwdops = []

    def fwdop(self, source):
        """Return the forward operator from ``source`` to this detector.

        Operators are cached per source: if a cached entry's source satisfies
        ``source.is_same(...)`` its operator is returned, otherwise a new one
        is built, appended to ``self.fwdops`` and returned.

        Parameters
        ----------
        source : object
            Source model exposing ``is_same``, ``grid`` and ``scale``
            (e.g. ``basic_source``).

        Returns
        -------
        The operator built by ``single_instrument_linear_operator_separable``.
        """
        # Return from inside the search so that a cache miss can never be
        # confused with "matched the last entry".
        for cached in self.fwdops:
            if source.is_same(cached['SOURCE']):
                return cached['OPERATOR']

        sresp = esr(source.grid, self.grid, self.transform)
        # Rescale by the ratio of the products of source and detector pixel
        # scales, and by the median of the response's column sums.
        sresp *= (np.prod(source.scale)/np.prod(self.scale)
                  / np.median(np.sum(sresp, axis=0).A1))
        operator = single_instrument_linear_operator_separable(
            sresp, self.tresp, temps=self.logt, exptime=self.meta['EXPTIME'])
        self.fwdops.append({'OPERATOR': operator, 'SOURCE': source})
        return operator
class basic_source(object):
    """Plane-of-sky + temperature source model covering a sequence of maps.

    Parameters
    ----------
    sequence : sequence of map-like objects
        Each element must expose ``.data``, ``.meta`` (including ``'LOGT'``)
        and, for the first element, ``.wcs``.
    super_fac : number, optional
        Factor applied to the finest pixel size found in the sequence to get
        the source pixel size.
    logt : array-like, optional
        Temperature grid for the source. When omitted, the channels' common
        ``'LOGT'`` grid is used if they all agree; otherwise a uniform grid
        spanning their combined range at the finest channel spacing is built.

    Raises
    ------
    ValueError
        If ``sequence`` is empty.
    """
    # This is a boneheaded source model that's based on an
    # ndcube data sequence (really anything that behaves like
    # a sunpy map should do). It assumes the map
    # has meta with crpix, cdelt, crval, and crota keywords
    # and does a basic transform assuming a fixed observer
    # coordinate. This should be updated to use the astropy WCS!!
    # Ideally the forward operation should be part of what the
    # data object does, but perhaps we don't want to load that
    # on the data objects since they might not be expected
    # to work on the multi-instrument DEM paradigm?
    # This is also a plane-of-sky+DEM model, and the modelheader
    # cdelt, etc parameters are written as if it's a simple image

    def __init__(self, sequence, super_fac=1, logt=None):
        nc = len(sequence)
        if nc == 0:
            raise ValueError("basic_source requires a non-empty sequence")

        def minmax(arg):
            return [np.min(arg), np.max(arg)]

        date = None
        # Need to find out the smallest pixel size in the sequence and the
        # maximum extent of the data in it
        for i in range(nc):
            chan = sequence[i]
            # Deliberately NOT named `logt`: rebinding the parameter here would
            # discard a caller-supplied grid and disable the checks below.
            chan_logt = chan.meta['LOGT']
            tmini, tmaxi = minmax(chan_logt)
            dti = np.min(np.abs(chan_logt[1:]-chan_logt[0:-1]))
            nxi, nyi = chan.data.shape
            bft = fits_transform(chan.meta)
            corner00, corner01 = bft.index2coord(0, 0), bft.index2coord(0, nyi)
            corner10, corner11 = bft.index2coord(nxi, 0), bft.index2coord(nxi, nyi)
            xmini, xmaxi = minmax([corner00[0], corner01[0], corner10[0], corner11[0]])
            ymini, ymaxi = minmax([corner00[1], corner01[1], corner10[1], corner11[1]])
            dxi = np.sum((bft.index2coord(1, 0)-corner00)**2)**0.5
            dyi = np.sum((bft.index2coord(0, 1)-corner00)**2)**0.5
            if i == 0:
                dx, dy, dt = dxi, dyi, dti
                tmin, xmin, ymin = tmini, xmini, ymini
                tmax, xmax, ymax = tmaxi, xmaxi, ymaxi
            else:
                dx, dy, dt = np.min([dxi, dx]), np.min([dyi, dy]), np.min([dti, dt])
                tmin, xmin, ymin = np.min([tmin, tmini]), np.min([xmin, xmini]), np.min([ymin, ymini])
                tmax, xmax, ymax = np.max([tmax, tmaxi]), np.max([xmax, xmaxi]), np.max([ymax, ymaxi])
            if date is None:
                date = chan.meta.get('DATE-OBS', chan.meta.get('DATE-AVG'))

        dx, dy = dx*super_fac, dy*super_fac
        nx = np.ceil((xmax-xmin)/dx).astype(np.uint32)+1
        ny = np.ceil((ymax-ymin)/dy).astype(np.uint32)+1

        if logt is None:
            # Figure out if all of the logts are the same
            reference = sequence[0].meta['LOGT']
            grids_agree = all(
                np.array_equal(sequence[i].meta['LOGT'], reference)
                for i in range(1, nc))
            if grids_agree:
                logt = reference
            else:
                npts = np.round((tmax-tmin)/dt).astype(np.uint32)+1
                logt = tmin+dt*np.arange(npts)
        else:
            logt = np.asarray(logt)
        nt = len(logt)

        self.logt = logt
        # Basis grid with one extra sample between each pair of end points:
        # 2*(nt-1)+1 evenly spaced values over the range of logt.
        basislogt = np.linspace(np.min(logt), np.max(logt), 2*(nt-1)+1)
        [self.logts, self.bases] = [np.tile(basislogt, [nt, 1]), np.zeros([nt, len(basislogt)])]
        # Triangular basis per temperature node: 1 where basislogt equals the
        # node, ramping linearly from the previous node and down to the next.
        for i in range(0, nt):
            self.bases[i] = (basislogt == logt[i])
            if(i > 0):
                self.bases[i] += (basislogt-logt[i-1])*(basislogt < logt[i])*(basislogt > logt[i-1])/(logt[i]-logt[i-1])
            if(i < nt-1):
                self.bases[i] += (logt[i+1]-basislogt)*(basislogt < logt[i+1])*(basislogt > logt[i])/(logt[i+1]-logt[i])

        # Split the grid's overshoot beyond the data extent evenly between
        # both sides.
        x0 = xmin - 0.5*((nx-1)*dx-(xmax-xmin))
        y0 = ymin - 0.5*((ny-1)*dy-(ymax-ymin))
        first = sequence[0]
        # NOTE(review): 'ctype1' is lower-case while every other key is
        # upper-case; kept as-is because consumers of this plain dict are not
        # visible here. Confirm whether 'CTYPE1' was intended.
        self.meta = {'CDELT1': dx, 'CDELT2': dy, 'CROTA': 0.0, 'CRPIX1': 1, 'CRPIX2': 1,
                     'RSUN_REF': first.meta['RSUN_REF'],
                     'CRVAL1': x0, 'CRVAL2': y0, 'NAXIS1': nx, 'NAXIS2': ny,
                     'DSUN_OBS': first.meta['DSUN_OBS'],
                     'CUNIT1': first.meta['CUNIT1'], 'CUNIT2': first.meta['CUNIT2'],
                     'HGLN_OBS': first.meta['HGLN_OBS'],
                     'ctype1': first.meta['ctype1'], 'CTYPE2': first.meta['CTYPE2'],
                     'HGLT_OBS': first.meta['HGLT_OBS'],
                     'LOGT': logt, 'PARENT_WCS': first.wcs,
                     'DATE-OBS': date, 'DATE-AVG': date}
        self.shape, self.axes = [nt, nx, ny], [logt, x0+dx*np.arange(nx), y0+dy*np.arange(ny)]
        self.scale = np.array([dx, dy])
        self.spatial_shape = np.array([nx, ny])
        dummymap = Map(np.zeros(self.spatial_shape, dtype=bool), self.meta)
        # This is wrong and a placeholder, since the model may have different
        # dimensions and scale. Need to build a WCS instead.
        self.wcs = dummymap.wcs
        self.transform = fits_transform(self.meta)
        self.origin = self.transform.index2coord(0.0, 0.0)
        self.ndim_spatial = len(self.origin)
        self.fwdtransform = np.matmul(self.transform.tfmat, np.diag(self.scale))
        self.frame = trivialframe(np.arange(1, self.ndim_spatial+1).astype(str))
        self.coords = coord_grid(self.spatial_shape, self.origin, self.fwdtransform, self.frame)
        self.grid = source_grid(self.coords, None, bin_function)

    def is_same(self, src):
        """Return True when ``src`` describes the same source grid as ``self``.

        Two sources are considered the same when ``src`` has a ``meta``
        attribute whose CDELT1/2, CROTA, CRPIX1/2, CRVAL1/2 and NAXIS1/2
        values equal this source's and whose ``'LOGT'`` has the same length.
        Only the length of ``'LOGT'`` is compared, not its values.
        """
        if not hasattr(src, 'meta'):
            return False
        geometry_keys = ('CDELT1', 'CDELT2', 'CROTA', 'CRPIX1', 'CRPIX2',
                         'CRVAL1', 'CRVAL2', 'NAXIS1', 'NAXIS2')
        same_geometry = all(self.meta[key] == src.meta.get(key)
                            for key in geometry_keys)
        same_nt = len(self.meta['LOGT']) == len(src.meta.get('LOGT', []))
        return bool(same_geometry and same_nt)